Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
address review comments
  • Loading branch information
kiszk committed Apr 4, 2018
commit 80579c5ebc0015b1a779b90c74b55fddefea4f0f
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import org.apache.spark.util.Utils

/**
Expand Down Expand Up @@ -127,6 +128,52 @@ object CallMethodViaReflection {
StringType -> Seq(classOf[String])
)

/**
 * Lookup table from a Catalyst [[DataType]] to a Java class.
 *
 * `dataTypeJavaClass` consults this table for every data type it does not handle with a
 * dedicated pattern; types absent from the table fall back to `java.lang.Object` there.
 */
val typeJavaMapping: Map[DataType, Class[_]] = Map[DataType, Class[_]](
  // Types mapped to the class of the corresponding Scala primitive.
  (BooleanType, classOf[Boolean]),
  (ByteType, classOf[Byte]),
  (ShortType, classOf[Short]),
  (IntegerType, classOf[Int]),
  (LongType, classOf[Long]),
  (FloatType, classOf[Float]),
  (DoubleType, classOf[Double]),
  // Types mapped to a dedicated class or to the type's own `InternalType`.
  (StringType, classOf[UTF8String]),
  (DateType, classOf[DateType.InternalType]),
  (TimestampType, classOf[TimestampType.InternalType]),
  (BinaryType, classOf[BinaryType.InternalType]),
  (CalendarIntervalType, classOf[CalendarInterval])
)

/**
 * Lookup table from a Catalyst [[DataType]] to a `java.lang` boxed class.
 *
 * Only the types listed here have an entry; `DateType` shares the boxed class of
 * `IntegerType` and `TimestampType` shares the boxed class of `LongType`.
 */
val typeBoxedJavaMapping: Map[DataType, Class[_]] = Map[DataType, Class[_]](
  (BooleanType, classOf[java.lang.Boolean]),
  (ByteType, classOf[java.lang.Byte]),
  (ShortType, classOf[java.lang.Short]),
  (IntegerType, classOf[java.lang.Integer]),
  (LongType, classOf[java.lang.Long]),
  (FloatType, classOf[java.lang.Float]),
  (DoubleType, classOf[java.lang.Double]),
  // Date and timestamp entries reuse the Integer and Long boxed classes.
  (DateType, classOf[java.lang.Integer]),
  (TimestampType, classOf[java.lang.Long])
)

/**
 * Returns the Java class associated with the given data type.
 *
 * Object, map, array, struct and decimal types are resolved by pattern; any other type is
 * looked up in `typeJavaMapping`, defaulting to `java.lang.Object` when it has no entry.
 *
 * @param dt the data type to resolve
 * @return the class associated with `dt`
 */
def dataTypeJavaClass(dt: DataType): Class[_] = dt match {
  case ObjectType(wrapped) => wrapped
  case _: MapType => classOf[MapData]
  case _: ArrayType => classOf[ArrayData]
  case _: StructType => classOf[InternalRow]
  case _: DecimalType => classOf[Decimal]
  case other => typeJavaMapping.getOrElse(other, classOf[java.lang.Object])
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the definitions above live in `CallMethodViaReflection` or in `CodeGenerator`?


/**
 * Returns, for each expression, the Java class that `dataTypeJavaClass` associates with the
 * expression's data type, in the same order as `arguments`.
 *
 * An empty `arguments` yields an empty result; no separate emptiness check is needed because
 * mapping over an empty sequence already produces an empty sequence.
 *
 * @param arguments the expressions whose data types should be resolved
 * @return one class per expression
 */
def expressionJavaClasses(arguments: Seq[Expression]): Seq[Class[_]] =
  arguments.map(e => dataTypeJavaClass(e.dataType))

/**
* Returns true if the class can be found and loaded.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,20 +222,31 @@ case class StaticInvoke(
// The result is nullable when either `needNullCheck` or `returnNullable` is set.
override def nullable: Boolean = needNullCheck || returnNullable
// The child expressions are exactly the invocation arguments.
override def children: Seq[Expression] = arguments



override def eval(input: InternalRow): Any = {
if (staticObject == null) {
throw new RuntimeException("The static class cannot be null.")
val args = arguments.map(e => e.eval(input).asInstanceOf[Object])
val argClasses = CallMethodViaReflection.expressionJavaClasses(arguments)
val cls = if (staticObject.getName == objectName) {
staticObject
} else {
Utils.classForName(objectName)
}
val method = cls.getDeclaredMethod(functionName, argClasses : _*)
if (needNullCheck && args.exists(_ == null)) {
// return null if one of arguments is null
null
} else {
val ret = method.invoke(null, args: _*)

val parmTypes = arguments.map(e =>
CallMethodViaReflection.typeMapping.getOrElse(e.dataType,
Seq(e.dataType.asInstanceOf[ObjectType].cls))(0))
val parms = arguments.map(e => e.eval(input).asInstanceOf[Object])
val method = staticObject.getDeclaredMethod(functionName, parmTypes : _*)
val ret = method.invoke(null, parms : _*)
val retClass = CallMethodViaReflection.typeMapping.getOrElse(dataType,
Seq(dataType.asInstanceOf[ObjectType].cls))(0)
retClass.cast(ret)
if (CodeGenerator.defaultValue(dataType) == "null") {
ret
} else {
// cast a primitive value using Boxed class
val boxedClass = CallMethodViaReflection.typeBoxedJavaMapping(dataType)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think we can do this directly, without the `if`:

val boxedClass = CallMethodViaReflection.typeBoxedJavaMapping.get(dataType)
boxedClass.map(_.cast(ret)).getOrElse(ret)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. Once the discussions about `typeJavaMapping`, `typeBoxedJavaMapping`, and the others are settled, I will address this.

boxedClass.cast(ret)
}
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,22 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.JavaConverters._
import scala.reflect.ClassTag

import java.sql.{Date, Timestamp}

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.SparkFunSuite
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{SQLDate, SQLTimestamp}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class InvokeTargetClass extends Serializable {
def filterInt(e: Any): Any = e.asInstanceOf[Int] > 0
Expand Down Expand Up @@ -93,6 +98,66 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
}

// Exercises StaticInvoke through checkObjectExprEvaluation with several static methods,
// covering string parsing, null propagation, zero/multiple arguments and a table of
// conversion helpers.
test("SPARK-23582: StaticInvoke should support interpreted execution") {
// Each row is (boxed class, string argument, expected value): the class's static `valueOf`
// is invoked with the string taken from the input row.
Seq((classOf[java.lang.Boolean], "true", true),
(classOf[java.lang.Byte], "1", 1.toByte),
(classOf[java.lang.Short], "257", 257.toShort),
(classOf[java.lang.Integer], "12345", 12345),
(classOf[java.lang.Long], "12345678", 12345678.toLong),
(classOf[java.lang.Float], "12.34", 12.34.toFloat),
(classOf[java.lang.Double], "1.2345678", 1.2345678)
).foreach { case (cls, arg, expected) =>
checkObjectExprEvaluation(StaticInvoke(cls, ObjectType(cls), "valueOf",
Seq(BoundReference(0, ObjectType(classOf[java.lang.String]), true))),
expected, InternalRow.fromSeq(Seq(arg)))
}

// Return null when null argument is passed with propagateNull = true
val stringCls = classOf[java.lang.String]
checkObjectExprEvaluation(StaticInvoke(stringCls, ObjectType(stringCls), "valueOf",
Seq(BoundReference(0, ObjectType(classOf[Object]), true)), propagateNull = true),
null, InternalRow.fromSeq(Seq(null)))
// With propagateNull = false the same null input is expected to produce the string "null".
checkObjectExprEvaluation(StaticInvoke(stringCls, ObjectType(stringCls), "valueOf",
Seq(BoundReference(0, ObjectType(classOf[Object]), true)), propagateNull = false),
"null", InternalRow.fromSeq(Seq(null)))

// test no argument
val clCls = classOf[java.lang.ClassLoader]
checkObjectExprEvaluation(StaticInvoke(clCls, ObjectType(clCls), "getSystemClassLoader", Nil),
ClassLoader.getSystemClassLoader, InternalRow.empty)
// test more than one argument
val intCls = classOf[java.lang.Integer]
checkObjectExprEvaluation(StaticInvoke(intCls, ObjectType(intCls), "compare",
Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, false))),
0, InternalRow.fromSeq(Seq(7, 7)))

// Each row is (class, result data type, method name, argument type, argument, expected);
// the expected value is computed by calling the same method directly.
Seq((DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", ObjectType(classOf[Timestamp]),
new Timestamp(77777), DateTimeUtils.fromJavaTimestamp(new Timestamp(77777))),
(DateTimeUtils.getClass, DateType, "fromJavaDate", ObjectType(classOf[Date]),
new Date(88888888), DateTimeUtils.fromJavaDate(new Date(88888888))),
(classOf[UTF8String], StringType, "fromString", ObjectType(classOf[String]),
"abc", UTF8String.fromString("abc")),
(Decimal.getClass, DecimalType(38, 0), "fromDecimal", ObjectType(classOf[Any]),
BigInt(88888888), Decimal.fromDecimal(BigInt(88888888))),
(Decimal.getClass, DecimalType.SYSTEM_DEFAULT,
"apply", ObjectType(classOf[java.math.BigInteger]),
new java.math.BigInteger("88888888"), Decimal.apply(new java.math.BigInteger("88888888"))),
(classOf[ArrayData], ArrayType(IntegerType), "toArrayData", ObjectType(classOf[Any]),
Array[Int](1, 2, 3), ArrayData.toArrayData(Array[Int](1, 2, 3))),
(classOf[UnsafeArrayData], ArrayType(IntegerType, false),
"fromPrimitiveArray", ObjectType(classOf[Array[Int]]),
Array[Int](1, 2, 3), UnsafeArrayData.fromPrimitiveArray(Array[Int](1, 2, 3))),
(DateTimeUtils.getClass, ObjectType(classOf[Date]),
"toJavaDate", ObjectType(classOf[SQLDate]), 77777, DateTimeUtils.toJavaDate(77777)),
(DateTimeUtils.getClass, ObjectType(classOf[Timestamp]),
"toJavaTimestamp", ObjectType(classOf[SQLTimestamp]),
88888888L, DateTimeUtils.toJavaTimestamp(88888888L))
).foreach { case (cls, dataType, methodName, argType, arg, expected) =>
// Single-argument invocation: the argument is read from ordinal 0 of the input row.
checkObjectExprEvaluation(StaticInvoke(cls, dataType, methodName,
Seq(BoundReference(0, argType, true))), expected, InternalRow.fromSeq(Seq(arg)))
}
}

test("SPARK-23583: Invoke should support interpreted execution") {
val targetObject = new InvokeTargetClass
val funcClass = classOf[InvokeTargetClass]
Expand Down