From 9fea286e9c7d47bf41b0f82ec44d55beced167e0 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 7 Mar 2018 00:30:40 +0000 Subject: [PATCH 01/13] initial prototype --- .../expressions/objects/objects.scala | 21 +++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index a455c1c821a2..1d7dffe4f15e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -24,7 +24,6 @@ import scala.collection.mutable.Builder import scala.language.existentials import scala.reflect.ClassTag import scala.util.Try - import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ import org.apache.spark.sql.Row @@ -32,9 +31,10 @@ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /** * Common base class for [[StaticInvoke]], [[Invoke]], and [[NewInstance]]. @@ -221,8 +221,21 @@ case class StaticInvoke( override def nullable: Boolean = needNullCheck || returnNullable override def children: Seq[Expression] = arguments - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + override def eval(input: InternalRow): Any = { + if (staticObject == null) { + throw new RuntimeException("The static class cannot be null.") + } + + val parmTypes = arguments.map(e => + CallMethodViaReflection.typeMapping.getOrElse(e.dataType, + Seq(e.dataType.asInstanceOf[ObjectType].cls))(0)) + val parms = arguments.map(e => e.eval(input).asInstanceOf[Object]) + val method = staticObject.getDeclaredMethod(functionName, parmTypes : _*) + val ret = method.invoke(null, parms : _*) + val retClass = CallMethodViaReflection.typeMapping.getOrElse(dataType, + Seq(dataType.asInstanceOf[ObjectType].cls))(0) + retClass.cast(ret) + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = CodeGenerator.javaType(dataType) From 76c61ade7a6af014e9e136150c5dc412a7f501d5 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 7 Mar 2018 01:06:50 +0000 Subject: [PATCH 02/13] fix style error --- .../spark/sql/catalyst/expressions/objects/objects.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 1d7dffe4f15e..15c4d467e1fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -24,6 +24,7 @@ import scala.collection.mutable.Builder import scala.language.existentials import scala.reflect.ClassTag import scala.util.Try + import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ import org.apache.spark.sql.Row @@ -31,7 +32,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils From 80579c5ebc0015b1a779b90c74b55fddefea4f0f Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 8 Mar 2018 13:24:09 +0000 Subject: [PATCH 03/13] address review comments --- .../expressions/CallMethodViaReflection.scala | 49 +++++++++++++- .../expressions/objects/objects.scala | 33 ++++++--- .../expressions/ObjectExpressionsSuite.scala | 67 ++++++++++++++++++- 3 files changed, 136 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala index 65bb9a8c642b..bc523caaae20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala @@ -23,8 +23,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} import org.apache.spark.util.Utils /** @@ -127,6 +128,52 @@ object CallMethodViaReflection { StringType -> Seq(classOf[String]) ) + val typeJavaMapping = Map[DataType, Class[_]]( + BooleanType -> classOf[Boolean], + ByteType -> classOf[Byte], + ShortType -> classOf[Short], + IntegerType -> classOf[Int], + LongType -> classOf[Long], + FloatType -> classOf[Float], + DoubleType -> classOf[Double], + StringType -> classOf[UTF8String], + DateType -> classOf[DateType.InternalType], + TimestampType -> classOf[TimestampType.InternalType], + BinaryType -> classOf[BinaryType.InternalType], + CalendarIntervalType -> classOf[CalendarInterval] + ) + + val typeBoxedJavaMapping = Map[DataType, Class[_]]( + BooleanType -> classOf[java.lang.Boolean], + ByteType -> classOf[java.lang.Byte], + ShortType -> classOf[java.lang.Short], + IntegerType -> classOf[java.lang.Integer], + LongType -> classOf[java.lang.Long], + FloatType -> classOf[java.lang.Float], + DoubleType -> classOf[java.lang.Double], + DateType -> classOf[java.lang.Integer], + TimestampType -> classOf[java.lang.Long] + ) + + def dataTypeJavaClass(dt: DataType): Class[_] = { + dt match { + case _: DecimalType => classOf[Decimal] + case _: StructType => classOf[InternalRow] + case _: ArrayType => classOf[ArrayData] + case _: MapType => classOf[MapData] + case ObjectType(cls) => cls + case _ => typeJavaMapping.getOrElse(dt, classOf[java.lang.Object]) + } + } + + def expressionJavaClasses(arguments: Seq[Expression]): Seq[Class[_]] = { + if (arguments != Nil) { + arguments.map(e => dataTypeJavaClass(e.dataType)) + } else { + Seq.empty + } + } + /** * Returns true if the class can be found and loaded. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 15c4d467e1fe..2ddd3f8bd59e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -222,20 +222,31 @@ case class StaticInvoke( override def nullable: Boolean = needNullCheck || returnNullable override def children: Seq[Expression] = arguments + + override def eval(input: InternalRow): Any = { - if (staticObject == null) { - throw new RuntimeException("The static class cannot be null.") + val args = arguments.map(e => e.eval(input).asInstanceOf[Object]) + val argClasses = CallMethodViaReflection.expressionJavaClasses(arguments) + val cls = if (staticObject.getName == objectName) { + staticObject + } else { + Utils.classForName(objectName) } + val method = cls.getDeclaredMethod(functionName, argClasses : _*) + if (needNullCheck && args.exists(_ == null)) { + // return null if one of arguments is null + null + } else { + val ret = method.invoke(null, args: _*) - val parmTypes = arguments.map(e => - CallMethodViaReflection.typeMapping.getOrElse(e.dataType, - Seq(e.dataType.asInstanceOf[ObjectType].cls))(0)) - val parms = arguments.map(e => e.eval(input).asInstanceOf[Object]) - val method = staticObject.getDeclaredMethod(functionName, parmTypes : _*) - val ret = method.invoke(null, parms : _*) - val retClass = CallMethodViaReflection.typeMapping.getOrElse(dataType, - Seq(dataType.asInstanceOf[ObjectType].cls))(0) - retClass.cast(ret) + if (CodeGenerator.defaultValue(dataType) == "null") { + ret + } else { + // cast a primitive value using Boxed class + val boxedClass = CallMethodViaReflection.typeBoxedJavaMapping(dataType) + boxedClass.cast(ret) + } + } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 9bfe2916b082..4cba0a16c4f5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -20,7 +20,10 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.JavaConverters._ import scala.reflect.ClassTag +import java.sql.{Date, Timestamp} + import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.SparkFunSuite import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow @@ -28,9 +31,11 @@ import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{SQLDate, SQLTimestamp} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class InvokeTargetClass extends Serializable { def filterInt(e: Any): Any = e.asInstanceOf[Int] > 0 @@ -93,6 +98,66 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed } + test("SPARK-23582: StaticInvoke should support interpreted execution") { + Seq((classOf[java.lang.Boolean], "true", true), + (classOf[java.lang.Byte], "1", 1.toByte), + (classOf[java.lang.Short], "257", 257.toShort), + (classOf[java.lang.Integer], "12345", 12345), + (classOf[java.lang.Long], "12345678", 12345678.toLong), + (classOf[java.lang.Float], "12.34", 12.34.toFloat), + (classOf[java.lang.Double], "1.2345678", 1.2345678) + ).foreach { case (cls, arg, expected) => + checkObjectExprEvaluation(StaticInvoke(cls, ObjectType(cls), "valueOf", + Seq(BoundReference(0, ObjectType(classOf[java.lang.String]), true))), + expected, InternalRow.fromSeq(Seq(arg))) + } + + // Return null when null argument is passed with propagateNull = true + val stringCls = classOf[java.lang.String] + checkObjectExprEvaluation(StaticInvoke(stringCls, ObjectType(stringCls), "valueOf", + Seq(BoundReference(0, ObjectType(classOf[Object]), true)), propagateNull = true), + null, InternalRow.fromSeq(Seq(null))) + checkObjectExprEvaluation(StaticInvoke(stringCls, ObjectType(stringCls), "valueOf", + Seq(BoundReference(0, ObjectType(classOf[Object]), true)), propagateNull = false), + "null", InternalRow.fromSeq(Seq(null))) + + // test no argument + val clCls = classOf[java.lang.ClassLoader] + checkObjectExprEvaluation(StaticInvoke(clCls, ObjectType(clCls), "getSystemClassLoader", Nil), + ClassLoader.getSystemClassLoader, InternalRow.empty) + // test more than one argument + val intCls = classOf[java.lang.Integer] + checkObjectExprEvaluation(StaticInvoke(intCls, ObjectType(intCls), "compare", + Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, false))), + 0, InternalRow.fromSeq(Seq(7, 7))) + + Seq((DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", ObjectType(classOf[Timestamp]), + new Timestamp(77777), DateTimeUtils.fromJavaTimestamp(new Timestamp(77777))), + (DateTimeUtils.getClass, DateType, "fromJavaDate", ObjectType(classOf[Date]), + new Date(88888888), DateTimeUtils.fromJavaDate(new Date(88888888))), + (classOf[UTF8String], StringType, "fromString", ObjectType(classOf[String]), + "abc", UTF8String.fromString("abc")), + (Decimal.getClass, DecimalType(38, 0), "fromDecimal", ObjectType(classOf[Any]), + BigInt(88888888), Decimal.fromDecimal(BigInt(88888888))), + (Decimal.getClass, DecimalType.SYSTEM_DEFAULT, + "apply", ObjectType(classOf[java.math.BigInteger]), + new java.math.BigInteger("88888888"), Decimal.apply(new java.math.BigInteger("88888888"))), + (classOf[ArrayData], ArrayType(IntegerType), "toArrayData", ObjectType(classOf[Any]), + Array[Int](1, 2, 3), ArrayData.toArrayData(Array[Int](1, 2, 3))), + (classOf[UnsafeArrayData], ArrayType(IntegerType, false), + "fromPrimitiveArray", ObjectType(classOf[Array[Int]]), + Array[Int](1, 2, 3), UnsafeArrayData.fromPrimitiveArray(Array[Int](1, 2, 3))), + (DateTimeUtils.getClass, ObjectType(classOf[Date]), + "toJavaDate", ObjectType(classOf[SQLDate]), 77777, DateTimeUtils.toJavaDate(77777)), + (DateTimeUtils.getClass, ObjectType(classOf[Timestamp]), + "toJavaTimestamp", ObjectType(classOf[SQLTimestamp]), + 88888888L, DateTimeUtils.toJavaTimestamp(88888888L)) + ).foreach { case (cls, dataType, methodName, argType, arg, expected) => + checkObjectExprEvaluation(StaticInvoke(cls, dataType, methodName, + Seq(BoundReference(0, argType, true))), expected, InternalRow.fromSeq(Seq(arg))) + } + } + test("SPARK-23583: Invoke should support interpreted execution") { val targetObject = new InvokeTargetClass val funcClass = classOf[InvokeTargetClass] From d02e926cb165b57519da5a7f2a2e25ac4b06967f Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 8 Mar 2018 13:44:39 +0000 Subject: [PATCH 04/13] fix scala style error --- .../spark/sql/catalyst/expressions/objects/objects.scala | 2 -- .../sql/catalyst/expressions/ObjectExpressionsSuite.scala | 5 ++--- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 2ddd3f8bd59e..c62ee83700be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -222,8 +222,6 @@ case class StaticInvoke( override def nullable: Boolean = needNullCheck || returnNullable override def children: Seq[Expression] = arguments - - override def eval(input: InternalRow): Any = { val args = arguments.map(e => e.eval(input).asInstanceOf[Object]) val argClasses = CallMethodViaReflection.expressionJavaClasses(arguments) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 4cba0a16c4f5..5369a719c9e0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.{Date, Timestamp} + import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import java.sql.{Date, Timestamp} - import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.SparkFunSuite import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow From b1f08c1b23d901ac10486d32a8ac0e3489cdc3e2 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 8 Mar 2018 16:40:06 +0000 Subject: [PATCH 05/13] fix test failures import checkObjectExprEvaluation for test --- .../sql/catalyst/expressions/ObjectExpressionsSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 5369a719c9e0..9f880420b06c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -150,7 +150,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { "toJavaDate", ObjectType(classOf[SQLDate]), 77777, DateTimeUtils.toJavaDate(77777)), (DateTimeUtils.getClass, ObjectType(classOf[Timestamp]), "toJavaTimestamp", ObjectType(classOf[SQLTimestamp]), - 88888888L, DateTimeUtils.toJavaTimestamp(88888888L)) + 88888888.toLong, DateTimeUtils.toJavaTimestamp(88888888)) ).foreach { case (cls, dataType, methodName, argType, arg, expected) => checkObjectExprEvaluation(StaticInvoke(cls, dataType, methodName, Seq(BoundReference(0, argType, true))), expected, InternalRow.fromSeq(Seq(arg))) @@ -216,6 +216,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(createExternalRow, Row.fromSeq(Seq(1, "x")), InternalRow.fromSeq(Seq())) } + // This is an alternative version of `checkEvaluation` to compare results // by scala values instead of catalyst values. private def checkObjectExprEvaluation( expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { From 7d66bb96c9d661763aff1b28457b3b69a636f739 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 26 Mar 2018 10:18:48 +0100 Subject: [PATCH 06/13] move utility methods to ScalaReflection --- .../expressions/CallMethodViaReflection.scala | 49 +------------------ .../expressions/objects/objects.scala | 7 ++- 2 files changed, 4 insertions(+), 52 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala index bc523caaae20..65bb9a8c642b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala @@ -23,9 +23,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils /** @@ -128,52 +127,6 @@ object CallMethodViaReflection { StringType -> Seq(classOf[String]) ) - val typeJavaMapping = Map[DataType, Class[_]]( - BooleanType -> classOf[Boolean], - ByteType -> classOf[Byte], - ShortType -> classOf[Short], - IntegerType -> classOf[Int], - LongType -> classOf[Long], - FloatType -> classOf[Float], - DoubleType -> classOf[Double], - StringType -> classOf[UTF8String], - DateType -> classOf[DateType.InternalType], - TimestampType -> classOf[TimestampType.InternalType], - BinaryType -> classOf[BinaryType.InternalType], - CalendarIntervalType -> classOf[CalendarInterval] - ) - - val typeBoxedJavaMapping = Map[DataType, Class[_]]( - BooleanType -> classOf[java.lang.Boolean], - ByteType -> classOf[java.lang.Byte], - ShortType -> classOf[java.lang.Short], - IntegerType -> classOf[java.lang.Integer], - LongType -> classOf[java.lang.Long], - FloatType -> classOf[java.lang.Float], - DoubleType -> classOf[java.lang.Double], - DateType -> classOf[java.lang.Integer], - TimestampType -> classOf[java.lang.Long] - ) - - def dataTypeJavaClass(dt: DataType): Class[_] = { - dt match { - case _: DecimalType => classOf[Decimal] - case _: StructType => classOf[InternalRow] - case _: ArrayType => classOf[ArrayData] - case _: MapType => classOf[MapData] - case ObjectType(cls) => cls - case _ => typeJavaMapping.getOrElse(dt, classOf[java.lang.Object]) - } - } - - def expressionJavaClasses(arguments: Seq[Expression]): Seq[Class[_]] = { - if (arguments != Nil) { - arguments.map(e => dataTypeJavaClass(e.dataType)) - } else { - Seq.empty - } - } - /** * Returns true if the class can be found and loaded. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index c62ee83700be..823c0376607d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -24,7 +24,6 @@ import scala.collection.mutable.Builder import scala.language.existentials import scala.reflect.ClassTag import scala.util.Try - import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ import org.apache.spark.sql.Row @@ -32,7 +31,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -224,7 +223,7 @@ case class StaticInvoke( override def eval(input: InternalRow): Any = { val args = arguments.map(e => e.eval(input).asInstanceOf[Object]) - val argClasses = CallMethodViaReflection.expressionJavaClasses(arguments) + val argClasses = ScalaReflection.expressionJavaClasses(arguments) val cls = if (staticObject.getName == objectName) { staticObject } else { @@ -241,7 +240,7 @@ case class StaticInvoke( ret } else { // cast a primitive value using Boxed class - val boxedClass = CallMethodViaReflection.typeBoxedJavaMapping(dataType) + val boxedClass = ScalaReflection.typeBoxedJavaMapping(dataType) boxedClass.cast(ret) } } From c66e4527b91ab8410df40c3f7551d33feb0e5713 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 26 Mar 2018 10:30:16 +0100 Subject: [PATCH 07/13] fix style error --- .../apache/spark/sql/catalyst/expressions/objects/objects.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 823c0376607d..37e93c707a1e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -24,6 +24,7 @@ import scala.collection.mutable.Builder import scala.language.existentials import scala.reflect.ClassTag import scala.util.Try + import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ import org.apache.spark.sql.Row From 1b3ea2973ad9db51f01ada1650ad7dad30beb04c Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 26 Mar 2018 11:09:42 +0100 Subject: [PATCH 08/13] fix style error --- .../apache/spark/sql/catalyst/expressions/objects/objects.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 37e93c707a1e..29c9c76fe68d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils From 7279557b032f26f079b67f07f2f87e981f6a1465 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 27 Mar 2018 17:58:32 +0100 Subject: [PATCH 09/13] address review comments --- .../expressions/objects/objects.scala | 31 ++++++------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 29c9c76fe68d..90acf1717c9d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -218,33 +218,20 @@ case class StaticInvoke( returnNullable: Boolean = true) extends InvokeLike { val objectName = staticObject.getName.stripSuffix("$") + val argClasses = ScalaReflection.expressionJavaClasses(arguments) + val cls = if (staticObject.getName == objectName) { + staticObject + } else { + Utils.classForName(objectName) + } override def nullable: Boolean = needNullCheck || returnNullable override def children: Seq[Expression] = arguments - override def eval(input: InternalRow): Any = { - val args = arguments.map(e => e.eval(input).asInstanceOf[Object]) - val argClasses = ScalaReflection.expressionJavaClasses(arguments) - val cls = if (staticObject.getName == objectName) { - staticObject - } else { - Utils.classForName(objectName) - } - val method = cls.getDeclaredMethod(functionName, argClasses : _*) - if (needNullCheck && args.exists(_ == null)) { - // return null if one of arguments is null - null - } else { - val ret = method.invoke(null, args: _*) + @transient lazy val method = cls.getDeclaredMethod(functionName, argClasses : _*) - if (CodeGenerator.defaultValue(dataType) == "null") { - ret - } else { - // cast a primitive value using Boxed class - val boxedClass = ScalaReflection.typeBoxedJavaMapping(dataType) - boxedClass.cast(ret) - } - } + override def eval(input: InternalRow): Any = { + invoke(null, method, arguments, input, dataType) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { From 0530dcc2c32eb7ebcd86d66ede6f3fad9a6f797f Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 28 Mar 2018 04:36:21 +0100 Subject: [PATCH 10/13] fix test failures --- .../apache/spark/sql/catalyst/expressions/objects/objects.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 90acf1717c9d..3fa91bd36bb6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -218,7 +218,6 @@ case class StaticInvoke( returnNullable: Boolean = true) extends InvokeLike { val objectName = staticObject.getName.stripSuffix("$") - val argClasses = ScalaReflection.expressionJavaClasses(arguments) val cls = if (staticObject.getName == objectName) { staticObject } else { @@ -228,6 +227,7 @@ case class StaticInvoke( override def nullable: Boolean = needNullCheck || returnNullable override def children: Seq[Expression] = arguments + lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments) @transient lazy val method = cls.getDeclaredMethod(functionName, argClasses : _*) override def eval(input: InternalRow): Any = { From 0c61e6063e4c28c2b6928fc74ddf959a1553bafc Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 4 Apr 2018 19:54:04 +0100 Subject: [PATCH 11/13] minimize changes --- .../sql/catalyst/expressions/ObjectExpressionsSuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 9f880420b06c..77c62b512575 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{SQLDate, SQLTimestamp} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -216,7 +216,6 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(createExternalRow, Row.fromSeq(Seq(1, "x")), InternalRow.fromSeq(Seq())) } - // This is an alternative version of `checkEvaluation` to compare results // by scala values instead of catalyst values. private def checkObjectExprEvaluation( expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { From 984cee7b07e2773fc3a07017516ce0884c752b8b Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 4 Apr 2018 20:38:12 +0100 Subject: [PATCH 12/13] fix build failure --- .../sql/catalyst/expressions/ObjectExpressionsSuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 77c62b512575..37b2744ec078 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -21,7 +21,6 @@ import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ import scala.reflect.ClassTag - import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.Row @@ -30,7 +29,7 @@ import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.DateTimeUtils.{SQLDate, SQLTimestamp} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ From 7b30c30a68107880cda108e6a0ed923c27ac6d56 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 5 Apr 2018 03:47:49 +0100 Subject: [PATCH 13/13] fix scala style error --- .../spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 37b2744ec078..1d59b20077fa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ import scala.reflect.ClassTag + import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.Row