Skip to content

Commit a355236

Browse files
kiszk authored and hvanhovell committed
[SPARK-23583][SQL] Invoke should support interpreted execution
## What changes were proposed in this pull request? This pr added interpreted execution for `Invoke`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Kazuaki Ishizaki <[email protected]> Closes apache#20797 from kiszk/SPARK-28583.
1 parent 5197562 commit a355236

File tree

3 files changed

+163
-6
lines changed

3 files changed

+163
-6
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
2020
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
2121
import org.apache.spark.sql.catalyst.expressions._
2222
import org.apache.spark.sql.catalyst.expressions.objects._
23-
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
23+
import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, MapData}
2424
import org.apache.spark.sql.types._
2525
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
2626

@@ -794,6 +794,52 @@ object ScalaReflection extends ScalaReflection {
794794
"interface", "long", "native", "new", "null", "package", "private", "protected", "public",
795795
"return", "short", "static", "strictfp", "super", "switch", "synchronized", "this", "throw",
796796
"throws", "transient", "true", "try", "void", "volatile", "while")
797+
798+
  // Maps a Catalyst DataType to the (unboxed, where applicable) Java class used
  // to hold its values at runtime. Date/Timestamp/Binary use their declared
  // InternalType representations.
  val typeJavaMapping = Map[DataType, Class[_]](
    BooleanType -> classOf[Boolean],
    ByteType -> classOf[Byte],
    ShortType -> classOf[Short],
    IntegerType -> classOf[Int],
    LongType -> classOf[Long],
    FloatType -> classOf[Float],
    DoubleType -> classOf[Double],
    StringType -> classOf[UTF8String],
    DateType -> classOf[DateType.InternalType],
    TimestampType -> classOf[TimestampType.InternalType],
    BinaryType -> classOf[BinaryType.InternalType],
    CalendarIntervalType -> classOf[CalendarInterval]
  )
812+
813+
  // Maps a Catalyst DataType to the boxed Java class of its runtime value.
  // Only types whose internal representation is a Java primitive appear here;
  // Date and Timestamp map to the boxed forms of their Int/Long internal types.
  val typeBoxedJavaMapping = Map[DataType, Class[_]](
    BooleanType -> classOf[java.lang.Boolean],
    ByteType -> classOf[java.lang.Byte],
    ShortType -> classOf[java.lang.Short],
    IntegerType -> classOf[java.lang.Integer],
    LongType -> classOf[java.lang.Long],
    FloatType -> classOf[java.lang.Float],
    DoubleType -> classOf[java.lang.Double],
    DateType -> classOf[java.lang.Integer],
    TimestampType -> classOf[java.lang.Long]
  )
824+
825+
def dataTypeJavaClass(dt: DataType): Class[_] = {
826+
dt match {
827+
case _: DecimalType => classOf[Decimal]
828+
case _: StructType => classOf[InternalRow]
829+
case _: ArrayType => classOf[ArrayData]
830+
case _: MapType => classOf[MapData]
831+
case ObjectType(cls) => cls
832+
case _ => typeJavaMapping.getOrElse(dt, classOf[java.lang.Object])
833+
}
834+
}
835+
836+
def expressionJavaClasses(arguments: Seq[Expression]): Seq[Class[_]] = {
837+
if (arguments != Nil) {
838+
arguments.map(e => dataTypeJavaClass(e.dataType))
839+
} else {
840+
Seq.empty
841+
}
842+
}
797843
}
798844

799845
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions.objects
1919

20-
import java.lang.reflect.Modifier
20+
import java.lang.reflect.{Method, Modifier}
2121

2222
import scala.collection.JavaConverters._
2323
import scala.collection.mutable.Builder
@@ -28,7 +28,7 @@ import scala.util.Try
2828
import org.apache.spark.{SparkConf, SparkEnv}
2929
import org.apache.spark.serializer._
3030
import org.apache.spark.sql.Row
31-
import org.apache.spark.sql.catalyst.InternalRow
31+
import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection}
3232
import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName
3333
import org.apache.spark.sql.catalyst.encoders.RowEncoder
3434
import org.apache.spark.sql.catalyst.expressions._
@@ -104,6 +104,38 @@ trait InvokeLike extends Expression with NonSQLExpression {
104104

105105
(argCode, argValues.mkString(", "), resultIsNull)
106106
}
107+
108+
/**
109+
* Evaluate each argument with a given row, invoke a method with a given object and arguments,
110+
* and cast a return value if the return type can be mapped to a Java Boxed type
111+
*
112+
* @param obj the object for the method to be called. If null, perform s static method call
113+
* @param method the method object to be called
114+
* @param arguments the arguments used for the method call
115+
* @param input the row used for evaluating arguments
116+
* @param dataType the data type of the return object
117+
* @return the return object of a method call
118+
*/
119+
def invoke(
120+
obj: Any,
121+
method: Method,
122+
arguments: Seq[Expression],
123+
input: InternalRow,
124+
dataType: DataType): Any = {
125+
val args = arguments.map(e => e.eval(input).asInstanceOf[Object])
126+
if (needNullCheck && args.exists(_ == null)) {
127+
// return null if one of arguments is null
128+
null
129+
} else {
130+
val ret = method.invoke(obj, args: _*)
131+
val boxedClass = ScalaReflection.typeBoxedJavaMapping.get(dataType)
132+
if (boxedClass.isDefined) {
133+
boxedClass.get.cast(ret)
134+
} else {
135+
ret
136+
}
137+
}
138+
}
107139
}
108140

109141
/**
@@ -264,12 +296,11 @@ case class Invoke(
264296
propagateNull: Boolean = true,
265297
returnNullable : Boolean = true) extends InvokeLike {
266298

  // Java classes of the argument expressions' data types, used to resolve the
  // target method reflectively during interpreted execution.
  lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)
300+
267301
override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable
268302
override def children: Seq[Expression] = targetObject +: arguments
269303

270-
override def eval(input: InternalRow): Any =
271-
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
272-
273304
private lazy val encodedFunctionName = TermName(functionName).encodedName.toString
274305

275306
@transient lazy val method = targetObject.dataType match {
@@ -283,6 +314,21 @@ case class Invoke(
283314
case _ => None
284315
}
285316

317+
override def eval(input: InternalRow): Any = {
318+
val obj = targetObject.eval(input)
319+
if (obj == null) {
320+
// return null if obj is null
321+
null
322+
} else {
323+
val invokeMethod = if (method.isDefined) {
324+
method.get
325+
} else {
326+
obj.getClass.getDeclaredMethod(functionName, argClasses: _*)
327+
}
328+
invoke(obj, invokeMethod, arguments, input, dataType)
329+
}
330+
}
331+
286332
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
287333
val javaType = CodeGenerator.javaType(dataType)
288334
val obj = targetObject.genCode(ctx)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,23 @@ import org.apache.spark.{SparkConf, SparkFunSuite}
2424
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
2525
import org.apache.spark.sql.Row
2626
import org.apache.spark.sql.catalyst.InternalRow
27+
import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
2728
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
29+
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
2830
import org.apache.spark.sql.catalyst.expressions.objects._
2931
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
32+
import org.apache.spark.sql.internal.SQLConf
3033
import org.apache.spark.sql.types._
3134

35+
// Test target for Invoke's interpreted execution. Serializable because the
// test round-trips expressions through Java serialization.
class InvokeTargetClass extends Serializable {
  // Any-typed signature: exercises the boxed argument/result path.
  def filterInt(e: Any): Any = e.asInstanceOf[Int] > 0
  // Primitive-typed signature: exercises the unboxed path.
  def filterPrimitiveInt(e: Int): Boolean = e > 0
  // Two-argument method; overridden (as subtraction) in InvokeTargetSubClass.
  def binOp(e1: Int, e2: Double): Double = e1 + e2
}
40+
41+
// Subclass used to verify that method resolution picks the overriding method
// of the target object's runtime class.
class InvokeTargetSubClass extends InvokeTargetClass {
  override def binOp(e1: Int, e2: Double): Double = e1 - e2
}
3244

3345
class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
3446

@@ -81,6 +93,41 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
8193
UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
8294
}
8395

96+
  test("SPARK-23583: Invoke should support interpreted execution") {
    val targetObject = new InvokeTargetClass
    val funcClass = classOf[InvokeTargetClass]
    val funcObj = Literal.create(targetObject, ObjectType(funcClass))
    val targetSubObject = new InvokeTargetSubClass
    val funcSubObj = Literal.create(targetSubObject, ObjectType(classOf[InvokeTargetSubClass]))
    val funcNullObj = Literal.create(null, ObjectType(funcClass))

    val inputInt = Seq(BoundReference(0, ObjectType(classOf[Any]), true))
    val inputPrimitiveInt = Seq(BoundReference(0, IntegerType, false))
    val inputSum = Seq(BoundReference(0, IntegerType, false), BoundReference(1, DoubleType, false))

    // Boxed (Any-typed) argument and result.
    checkObjectExprEvaluation(
      Invoke(funcObj, "filterInt", ObjectType(classOf[Any]), inputInt),
      java.lang.Boolean.valueOf(true), InternalRow.fromSeq(Seq(Integer.valueOf(1))))

    // Primitive argument and result.
    checkObjectExprEvaluation(
      Invoke(funcObj, "filterPrimitiveInt", BooleanType, inputPrimitiveInt),
      false, InternalRow.fromSeq(Seq(-1)))

    // A null argument propagates to a null result.
    checkObjectExprEvaluation(
      Invoke(funcObj, "filterInt", ObjectType(classOf[Any]), inputInt),
      null, InternalRow.fromSeq(Seq(null)))

    // A null target object yields a null result.
    checkObjectExprEvaluation(
      Invoke(funcNullObj, "filterInt", ObjectType(classOf[Any]), inputInt),
      null, InternalRow.fromSeq(Seq(Integer.valueOf(1))))

    // Multiple arguments of mixed primitive types.
    checkObjectExprEvaluation(
      Invoke(funcObj, "binOp", DoubleType, inputSum), 1.25, InternalRow.apply(1, 0.25))

    // The subclass's overriding binOp (subtraction) is invoked.
    checkObjectExprEvaluation(
      Invoke(funcSubObj, "binOp", DoubleType, inputSum), 0.75, InternalRow.apply(1, 0.25))
  }
130+
84131
test("SPARK-23585: UnwrapOption should support interpreted execution") {
85132
val cls = classOf[Option[Int]]
86133
val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
@@ -105,6 +152,24 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
105152
checkEvaluation(createExternalRow, Row.fromSeq(Seq(1, "x")), InternalRow.fromSeq(Seq()))
106153
}
107154

155+
  // Checks an expression under all evaluation modes (interpreted, generated
  // mutable projection, unsafe projection when supported, and after optimizer
  // rules), with the expected result given as a Scala value instead of a
  // Catalyst value.
  private def checkObjectExprEvaluation(
      expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = {
    val serializer = new JavaSerializer(new SparkConf()).newInstance
    val resolver = ResolveTimeZone(new SQLConf)
    // Round-trip through Java serialization first so the check also verifies
    // the expression is serializable.
    val expr = resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression)))
    checkEvaluationWithoutCodegen(expr, expected, inputRow)
    checkEvaluationWithGeneratedMutableProjection(expr, expected, inputRow)
    if (GenerateUnsafeProjection.canSupport(expr.dataType)) {
      checkEvaluationWithUnsafeProjection(
        expr,
        expected,
        inputRow,
        UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
    }
    checkEvaluationWithOptimization(expr, expected, inputRow)
  }
172+
108173
test("SPARK-23594 GetExternalRowField should support interpreted execution") {
109174
val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true)
110175
val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0")

0 commit comments

Comments
 (0)