Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

Expand Down Expand Up @@ -794,6 +794,52 @@ object ScalaReflection extends ScalaReflection {
"interface", "long", "native", "new", "null", "package", "private", "protected", "public",
"return", "short", "static", "strictfp", "super", "switch", "synchronized", "this", "throw",
"throws", "transient", "true", "try", "void", "volatile", "while")

/**
 * Maps Catalyst data types to the Java class that represents their values.
 * Boolean and numeric types map to the corresponding primitive classes; the remaining
 * entries map to the value class named on the right-hand side.
 */
val typeJavaMapping: Map[DataType, Class[_]] = Map[DataType, Class[_]](
  // Primitive-backed types.
  BooleanType -> classOf[Boolean],
  ByteType -> classOf[Byte],
  ShortType -> classOf[Short],
  IntegerType -> classOf[Int],
  LongType -> classOf[Long],
  FloatType -> classOf[Float],
  DoubleType -> classOf[Double],
  // Types represented by a dedicated class or by the type's own `InternalType`.
  StringType -> classOf[UTF8String],
  DateType -> classOf[DateType.InternalType],
  TimestampType -> classOf[TimestampType.InternalType],
  BinaryType -> classOf[BinaryType.InternalType],
  CalendarIntervalType -> classOf[CalendarInterval]
)

/**
 * Maps Catalyst data types to the `java.lang` boxed class of their values.
 * `DateType` and `TimestampType` are included and map to `java.lang.Integer` and
 * `java.lang.Long` respectively. Types without an entry here have no boxed mapping.
 */
val typeBoxedJavaMapping: Map[DataType, Class[_]] = Map[DataType, Class[_]](
  BooleanType -> classOf[java.lang.Boolean],
  ByteType -> classOf[java.lang.Byte],
  ShortType -> classOf[java.lang.Short],
  IntegerType -> classOf[java.lang.Integer],
  LongType -> classOf[java.lang.Long],
  FloatType -> classOf[java.lang.Float],
  DoubleType -> classOf[java.lang.Double],
  // Date and timestamp values use the same boxed classes as int and long.
  DateType -> classOf[java.lang.Integer],
  TimestampType -> classOf[java.lang.Long]
)

/**
 * Returns the Java class used to hold values of the given Catalyst data type.
 *
 * Decimal, struct, array and map types resolve to their fixed container classes, an
 * `ObjectType` resolves to the class it wraps, and every other type is looked up in
 * `typeJavaMapping`, falling back to `java.lang.Object` when there is no entry.
 *
 * @param dt the data type to resolve
 * @return the Java class for values of `dt`
 */
def dataTypeJavaClass(dt: DataType): Class[_] = dt match {
  case _: DecimalType => classOf[Decimal]
  case _: StructType => classOf[InternalRow]
  case _: ArrayType => classOf[ArrayData]
  case _: MapType => classOf[MapData]
  case ObjectType(wrapped) => wrapped
  case other => typeJavaMapping.getOrElse(other, classOf[java.lang.Object])
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @maropu We can now use this instead of CallMethodViaReflection.typeMapping.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok


/**
 * Returns the Java class of each expression's data type, in argument order.
 *
 * An empty argument list yields an empty result; `map` already handles that case, so no
 * separate emptiness check is needed.
 *
 * @param arguments the expressions whose data types are resolved
 * @return one Java class per argument, as given by `dataTypeJavaClass`
 */
def expressionJavaClasses(arguments: Seq[Expression]): Seq[Class[_]] =
  arguments.map(e => dataTypeJavaClass(e.dataType))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.expressions.objects

import java.lang.reflect.Modifier
import java.lang.reflect.{Method, Modifier}

import scala.collection.JavaConverters._
import scala.collection.mutable.Builder
Expand All @@ -28,7 +28,7 @@ import scala.util.Try
import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.serializer._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection}
import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
Expand Down Expand Up @@ -104,6 +104,38 @@ trait InvokeLike extends Expression with NonSQLExpression {

(argCode, argValues.mkString(", "), resultIsNull)
}

/**
 * Evaluates every argument against the given row, calls `method` reflectively on `obj`
 * with the evaluated values, and casts the result to the boxed Java class registered
 * for `dataType` when such a mapping exists.
 *
 * When `needNullCheck` is set and any evaluated argument is null, the method is not
 * called and null is returned.
 *
 * @param obj the object the method is called on. If null, perform a static method call
 * @param method the method object to be called
 * @param arguments the arguments used for the method call
 * @param input the row used for evaluating arguments
 * @param dataType the data type of the return object
 * @return the return object of a method call
 */
def invoke(
    obj: Any,
    method: Method,
    arguments: Seq[Expression],
    input: InternalRow,
    dataType: DataType): Any = {
  val evaluatedArgs = arguments.map(_.eval(input).asInstanceOf[Object])
  val skipCall = needNullCheck && evaluatedArgs.contains(null)
  if (skipCall) {
    // a null argument short-circuits the call
    null
  } else {
    val result = method.invoke(obj, evaluatedArgs: _*)
    ScalaReflection.typeBoxedJavaMapping.get(dataType) match {
      case Some(boxed) => boxed.cast(result)
      case None => result
    }
  }
}
}

/**
Expand Down Expand Up @@ -264,12 +296,11 @@ case class Invoke(
propagateNull: Boolean = true,
returnNullable : Boolean = true) extends InvokeLike {

lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)

override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable
override def children: Seq[Expression] = targetObject +: arguments

override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")

private lazy val encodedFunctionName = TermName(functionName).encodedName.toString

@transient lazy val method = targetObject.dataType match {
Expand All @@ -283,6 +314,21 @@ case class Invoke(
case _ => None
}

/**
 * Interpreted evaluation: evaluates the target object, then invokes the function on it.
 *
 * Returns null when the target object evaluates to null. Otherwise the pre-resolved
 * `method` is used when available; if it is not, the method is looked up on the runtime
 * class of the target object using the argument classes.
 *
 * @param input the row used to evaluate the target object and the arguments
 * @return the value returned by the invoked method, or null as described above
 */
override def eval(input: InternalRow): Any = {
  val obj = targetObject.eval(input)
  if (obj == null) {
    // return null if obj is null
    null
  } else {
    val invokeMethod = method.getOrElse {
      // `getMethod` also finds public methods inherited from superclasses and interfaces,
      // which `getDeclaredMethod` does not. The encoded name is used so that the lookup
      // agrees with the name used to resolve `method`.
      obj.getClass.getMethod(encodedFunctionName, argClasses: _*)
    }
    invoke(obj, invokeMethod, arguments, input, dataType)
  }
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val javaType = CodeGenerator.javaType(dataType)
val obj = targetObject.genCode(ctx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,23 @@ import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

/** Test fixture whose methods are the targets of reflective calls in the tests below. */
class InvokeTargetClass extends Serializable {
  /** Casts the argument to `Int` and reports whether it is greater than zero. */
  def filterInt(e: Any): Any = {
    val asInt = e.asInstanceOf[Int]
    asInt > 0
  }

  /** Reports whether the primitive argument is greater than zero. */
  def filterPrimitiveInt(e: Int): Boolean = {
    e > 0
  }

  /** Adds the two operands. */
  def binOp(e1: Int, e2: Double): Double = {
    e1 + e2
  }
}

/** Subclass fixture that overrides `binOp` to subtract instead of add. */
class InvokeTargetSubClass extends InvokeTargetClass {
  override def binOp(e1: Int, e2: Double): Double = {
    e1 - e2
  }
}

class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {

Expand Down Expand Up @@ -81,6 +93,41 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
}

test("SPARK-23583: Invoke should support interpreted execution") {
  // Invocation targets: a base-class instance, a subclass instance that overrides `binOp`,
  // and a null reference typed as the base class.
  val baseClass = classOf[InvokeTargetClass]
  val baseTarget = Literal.create(new InvokeTargetClass, ObjectType(baseClass))
  val subTarget =
    Literal.create(new InvokeTargetSubClass, ObjectType(classOf[InvokeTargetSubClass]))
  val nullTarget = Literal.create(null, ObjectType(baseClass))

  // Argument expressions, each bound to an ordinal of the input row.
  val objectArg = Seq(BoundReference(0, ObjectType(classOf[Any]), true))
  val intArg = Seq(BoundReference(0, IntegerType, false))
  val intAndDoubleArgs =
    Seq(BoundReference(0, IntegerType, false), BoundReference(1, DoubleType, false))

  // Object-typed argument and object-typed result.
  checkObjectExprEvaluation(
    Invoke(baseTarget, "filterInt", ObjectType(classOf[Any]), objectArg),
    java.lang.Boolean.valueOf(true),
    InternalRow.fromSeq(Seq(Integer.valueOf(1))))

  // Primitive int argument and boolean result.
  checkObjectExprEvaluation(
    Invoke(baseTarget, "filterPrimitiveInt", BooleanType, intArg),
    false,
    InternalRow.fromSeq(Seq(-1)))

  // A null argument is expected to produce null.
  checkObjectExprEvaluation(
    Invoke(baseTarget, "filterInt", ObjectType(classOf[Any]), objectArg),
    null,
    InternalRow.fromSeq(Seq(null)))

  // A null target object is expected to produce null.
  checkObjectExprEvaluation(
    Invoke(nullTarget, "filterInt", ObjectType(classOf[Any]), objectArg),
    null,
    InternalRow.fromSeq(Seq(Integer.valueOf(1))))

  // Two-argument call on the base class: 1 + 0.25.
  checkObjectExprEvaluation(
    Invoke(baseTarget, "binOp", DoubleType, intAndDoubleArgs),
    1.25,
    InternalRow(1, 0.25))

  // Same call on the subclass, which overrides `binOp`: 1 - 0.25.
  checkObjectExprEvaluation(
    Invoke(subTarget, "binOp", DoubleType, intAndDoubleArgs),
    0.75,
    InternalRow(1, 0.25))
}

test("SPARK-23585: UnwrapOption should support interpreted execution") {
val cls = classOf[Option[Int]]
val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
Expand All @@ -105,6 +152,24 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(createExternalRow, Row.fromSeq(Seq(1, "x")), InternalRow.fromSeq(Seq()))
}

// Checks an object expression through each of the evaluation helpers called below,
// comparing results by scala values instead of catalyst values. The expression is
// round-tripped through Java serialization before it is evaluated.
private def checkObjectExprEvaluation(
    expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = {
  val javaSerializer = new JavaSerializer(new SparkConf()).newInstance
  val timeZoneResolver = ResolveTimeZone(new SQLConf)
  val expr = timeZoneResolver.resolveTimeZones(
    javaSerializer.deserialize(javaSerializer.serialize(expression)))

  checkEvaluationWithoutCodegen(expr, expected, inputRow)
  checkEvaluationWithGeneratedMutableProjection(expr, expected, inputRow)

  // The unsafe projection check only runs for data types it can support.
  val unsafeSupported = GenerateUnsafeProjection.canSupport(expr.dataType)
  if (unsafeSupported) {
    checkEvaluationWithUnsafeProjection(
      expr,
      expected,
      inputRow,
      UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
  }

  checkEvaluationWithOptimization(expr, expected, inputRow)
}

test("SPARK-23594 GetExternalRowField should support interpreted execution") {
val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true)
val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0")
Expand Down