From 78de6c8fb7be454f1f0dcdb49a460c4fa086fcd9 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 7 Mar 2018 02:06:28 +0900 Subject: [PATCH 1/7] ValidateExternalType should support interpreted execution --- .../spark/sql/catalyst/ScalaReflection.scala | 13 +++++++ .../expressions/objects/objects.scala | 32 +++++++++++++++-- .../expressions/ObjectExpressionsSuite.scala | 34 +++++++++++++++++-- 3 files changed, 74 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 818cc2fb1e8a..fc3d7fdaa5cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -125,6 +125,19 @@ object ScalaReflection extends ScalaReflection { case _ => false } + def classForNativeTypeOf(dt: DataType): Class[_] = dt match { + case NullType => classOf[Object] + case BooleanType => classOf[java.lang.Boolean] + case ByteType => classOf[java.lang.Byte] + case ShortType => classOf[java.lang.Short] + case IntegerType => classOf[java.lang.Integer] + case LongType => classOf[java.lang.Long] + case FloatType => classOf[java.lang.Float] + case DoubleType => classOf[java.lang.Double] + case BinaryType => classOf[Array[Byte]] + case CalendarIntervalType => classOf[CalendarInterval] + } + /** * Returns an expression that can be used to deserialize an input row to an object of type `T` * with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index bc17d1229420..9cb5307dd543 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1672,11 +1672,37 @@ case class ValidateExternalType(child: Expression, expected: DataType) override def dataType: DataType = RowEncoder.externalDataTypeForInput(expected) - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") - private val errMsg = s" is not a valid external type for schema of ${expected.simpleString}" + private lazy val checkType = expected match { + case _: DecimalType => + (value: Any) => { + Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal]) + .exists { x => value.getClass.isAssignableFrom(x) } + } + case _: ArrayType => + (value: Any) => { + value.getClass.isAssignableFrom(classOf[Seq[_]]) || value.getClass.isArray + } + case _ if ScalaReflection.isNativeType(expected) => + (value: Any) => { + value.getClass.isAssignableFrom(ScalaReflection.classForNativeTypeOf(expected)) + } + case _ => + (value: Any) => { + value.getClass.isAssignableFrom(dataType.asInstanceOf[ObjectType].cls) + } + } + + override def eval(input: InternalRow): Any = { + val result = child.eval(input) + if (checkType(result)) { + result + } else { + throw new RuntimeException(s"${result.getClass.getName}$errMsg") + } + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the type doesn't match. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index bcd035c1eba0..ca6c4c3a285f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, Generic import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class InvokeTargetClass extends Serializable { def filterInt(e: Any): Any = e.asInstanceOf[Int] > 0 @@ -274,6 +274,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(createExternalRow, Row.fromSeq(Seq(1, "x")), InternalRow.fromSeq(Seq())) } + // This is an alternative version of `checkEvaluation` to compare results // by scala values instead of catalyst values. private def checkObjectExprEvaluation( expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { @@ -296,7 +297,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0") Seq((Row(1), 1), (Row(3), 3)).foreach { case (input, expected) => - checkEvaluation(getRowField, expected, InternalRow.fromSeq(Seq(input))) + checkObjectExprEvaluation(getRowField, expected, InternalRow.fromSeq(Seq(input))) } // If an input row or a field are null, a runtime exception will be thrown @@ -472,6 +473,35 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val deserializer = toMapExpr.copy(inputData = Literal.create(data)) checkObjectExprEvaluation(deserializer, expected = data) } + + test("SPARK-23595 ValidateExternalType should support interpreted execution") { + val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) + Seq( + (true, BooleanType), + (2.toByte, ByteType), + (5.toShort, ShortType), + (23, IntegerType), + (61L, LongType), + (1.0f, FloatType), + (10.0, DoubleType), + ("abcd".getBytes, BinaryType), + ("abcd", StringType), + (BigDecimal.valueOf(10), DecimalType.IntDecimal), + (CalendarInterval.fromString("interval 3 day"), CalendarIntervalType), + (java.math.BigDecimal.valueOf(10), DecimalType.BigIntDecimal), + (Array(3, 2, 1), ArrayType(IntegerType)) + ).foreach { case (input, dt) => + val validateType = ValidateExternalType( + GetExternalRowField(inputObject, index = 0, fieldName = "c0"), dt) + checkObjectExprEvaluation(validateType, input, InternalRow.fromSeq(Seq(Row(input)))) + } + + checkExceptionInExpression[RuntimeException]( + ValidateExternalType( + GetExternalRowField(inputObject, index = 0, fieldName = "c0"), DoubleType), + InternalRow.fromSeq(Seq(Row(1))), + "java.lang.Integer is not a valid external type for schema of double") + } } class TestBean extends Serializable { From 4bee661848f85144ec6672e1f6aefb6802b5cbe3 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 7 Mar 2018 23:17:03 +0900 Subject: [PATCH 2/7] Fix --- .../expressions/objects/objects.scala | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 9cb5307dd543..cf239222afcb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1670,27 +1670,30 @@ case class ValidateExternalType(child: Expression, expected: DataType) override def nullable: Boolean = child.nullable - override def dataType: DataType = RowEncoder.externalDataTypeForInput(expected) + override val dataType: DataType = RowEncoder.externalDataTypeForInput(expected) private val errMsg = s" is not a valid external type for schema of ${expected.simpleString}" + private lazy val dataTypeClazz = if (dataType.isInstanceOf[ObjectType]) { + dataType.asInstanceOf[ObjectType].cls + } else { + // Some external types (e.g., native types and `PythonUserDefinedType`) might not be ObjectType + ScalaReflection.classForNativeTypeOf(dataType) + } + private lazy val checkType = expected match { case _: DecimalType => (value: Any) => { - Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal]) - .exists { x => value.getClass.isAssignableFrom(x) } + value.isInstanceOf[java.math.BigDecimal] || value.isInstanceOf[scala.math.BigDecimal] || + value.isInstanceOf[Decimal] } case _: ArrayType => (value: Any) => { - value.getClass.isAssignableFrom(classOf[Seq[_]]) || value.getClass.isArray - } - case _ if ScalaReflection.isNativeType(expected) => - (value: Any) => { - value.getClass.isAssignableFrom(ScalaReflection.classForNativeTypeOf(expected)) + value.getClass.isArray || value.isInstanceOf[Seq[_]] } case _ => (value: Any) => { - value.getClass.isAssignableFrom(dataType.asInstanceOf[ObjectType].cls) + dataTypeClazz.isInstance(value) } } @@ -1715,7 +1718,7 @@ case class ValidateExternalType(child: Expression, expected: DataType) Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal]) .map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ") case _: ArrayType => - s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()" + s"$obj.getClass().isArray() || $obj instanceof ${classOf[Seq[_]].getName}" case _ => s"$obj instanceof ${CodeGenerator.boxedType(dataType)}" } From 647df9fefb7f151d8aa1f804a7aa57908031df14 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 8 Mar 2018 21:42:40 +0900 Subject: [PATCH 3/7] Fix --- .../catalyst/expressions/objects/objects.scala | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index cf239222afcb..e3f8393da858 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1674,14 +1674,8 @@ case class ValidateExternalType(child: Expression, expected: DataType) private val errMsg = s" is not a valid external type for schema of ${expected.simpleString}" - private lazy val dataTypeClazz = if (dataType.isInstanceOf[ObjectType]) { - dataType.asInstanceOf[ObjectType].cls - } else { - // Some external types (e.g., native types and `PythonUserDefinedType`) might not be ObjectType - ScalaReflection.classForNativeTypeOf(dataType) - } - private lazy val checkType = expected match { + private lazy val checkType: (Any) => Boolean = expected match { case _: DecimalType => (value: Any) => { value.isInstanceOf[java.math.BigDecimal] || value.isInstanceOf[scala.math.BigDecimal] || @@ -1692,6 +1686,13 @@ case class ValidateExternalType(child: Expression, expected: DataType) value.getClass.isArray || value.isInstanceOf[Seq[_]] } case _ => + val dataTypeClazz = if (dataType.isInstanceOf[ObjectType]) { + dataType.asInstanceOf[ObjectType].cls + } else { + // Some external types (e.g., native types and `PythonUserDefinedType`) + // might not be ObjectType + ScalaReflection.classForNativeTypeOf(dataType) + } (value: Any) => { dataTypeClazz.isInstance(value) } From 9b6b3146303c1c2bad7a32849fae7cda507ad451 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 9 Mar 2018 12:24:05 +0900 Subject: [PATCH 4/7] Fix --- .../spark/sql/catalyst/ScalaReflection.scala | 13 ---------- .../sql/catalyst/encoders/RowEncoder.scala | 24 +++++++++++++++++-- .../expressions/objects/objects.scala | 9 +------ 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index fc3d7fdaa5cf..818cc2fb1e8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -125,19 +125,6 @@ object ScalaReflection extends ScalaReflection { case _ => false } - def classForNativeTypeOf(dt: DataType): Class[_] = dt match { - case NullType => classOf[Object] - case BooleanType => classOf[java.lang.Boolean] - case ByteType => classOf[java.lang.Byte] - case ShortType => classOf[java.lang.Short] - case IntegerType => classOf[java.lang.Integer] - case LongType => classOf[java.lang.Long] - case FloatType => classOf[java.lang.Float] - case DoubleType => classOf[java.lang.Double] - case BinaryType => classOf[Array[Byte]] - case CalendarIntervalType => classOf[CalendarInterval] - } - /** * Returns an expression that can be used to deserialize an input row to an object of type `T` * with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 789750fd408f..346b324121bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -26,9 +26,9 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} /** * A factory for constructing encoders that convert external row to/from the Spark SQL @@ -235,6 +235,26 @@ object RowEncoder { case udt: UserDefinedType[_] => ObjectType(udt.userClass) } + // Returns the runtime class corresponding to the provided external type that + // is retrieved by `externalDataTypeForInput`. Note that `PythonUserDefinedType` and + // `UserDefinedType` are converted into native types or `ObjectType`s in `externalDataTypeFor`, + // so this method can handle both types correctly. + def getClassFromExternalType(externalType: DataType): Class[_] = externalType match { + case NullType => classOf[Object] + case BooleanType => classOf[java.lang.Boolean] + case ByteType => classOf[java.lang.Byte] + case ShortType => classOf[java.lang.Short] + case IntegerType => classOf[java.lang.Integer] + case LongType => classOf[java.lang.Long] + case FloatType => classOf[java.lang.Float] + case DoubleType => classOf[java.lang.Double] + case BinaryType => classOf[Array[Byte]] + case CalendarIntervalType => classOf[CalendarInterval] + // External types for the other types (e.g., array, map, and struct) + // must be `ObjectType`. + case ObjectType(cls) => cls + } + private def deserializerFor(schema: StructType): Expression = { val fields = schema.zipWithIndex.map { case (f, i) => val dt = f.dataType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index e3f8393da858..51145648534f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1674,7 +1674,6 @@ case class ValidateExternalType(child: Expression, expected: DataType) private val errMsg = s" is not a valid external type for schema of ${expected.simpleString}" - private lazy val checkType: (Any) => Boolean = expected match { case _: DecimalType => (value: Any) => { @@ -1686,13 +1685,7 @@ case class ValidateExternalType(child: Expression, expected: DataType) value.getClass.isArray || value.isInstanceOf[Seq[_]] } case _ => - val dataTypeClazz = if (dataType.isInstanceOf[ObjectType]) { - dataType.asInstanceOf[ObjectType].cls - } else { - // Some external types (e.g., native types and `PythonUserDefinedType`) - // might not be ObjectType - ScalaReflection.classForNativeTypeOf(dataType) - } + val dataTypeClazz = RowEncoder.getClassFromExternalType(dataType) (value: Any) => { dataTypeClazz.isInstance(value) } From 8cec38288f1ba0a48129693e9da9573448cf91b1 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 12 Mar 2018 15:10:21 +0900 Subject: [PATCH 5/7] Fix --- .../sql/catalyst/encoders/RowEncoder.scala | 22 +---------------- .../expressions/objects/objects.scala | 24 ++++++++++++++++++- .../expressions/ObjectExpressionsSuite.scala | 1 - 3 files changed, 24 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 346b324121bc..3340789398f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.unsafe.types.UTF8String /** * A factory for constructing encoders that convert external row to/from the Spark SQL @@ -235,26 +235,6 @@ object RowEncoder { case udt: UserDefinedType[_] => ObjectType(udt.userClass) } - // Returns the runtime class corresponding to the provided external type that - // is retrieved by `externalDataTypeForInput`. Note that `PythonUserDefinedType` and - // `UserDefinedType` are converted into native types or `ObjectType`s in `externalDataTypeFor`, - // so this method can handle both types correctly. - def getClassFromExternalType(externalType: DataType): Class[_] = externalType match { - case NullType => classOf[Object] - case BooleanType => classOf[java.lang.Boolean] - case ByteType => classOf[java.lang.Byte] - case ShortType => classOf[java.lang.Short] - case IntegerType => classOf[java.lang.Integer] - case LongType => classOf[java.lang.Long] - case FloatType => classOf[java.lang.Float] - case DoubleType => classOf[java.lang.Double] - case BinaryType => classOf[Array[Byte]] - case CalendarIntervalType => classOf[CalendarInterval] - // External types for the other types (e.g., array, map, and struct) - // must be `ObjectType`. - case ObjectType(cls) => cls - } - private def deserializerFor(schema: StructType): Expression = { val fields = schema.zipWithIndex.map { case (f, i) => val dt = f.dataType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 51145648534f..bf25b0ad567b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} import org.apache.spark.util.Utils /** @@ -1674,6 +1675,27 @@ case class ValidateExternalType(child: Expression, expected: DataType) private val errMsg = s" is not a valid external type for schema of ${expected.simpleString}" + // This function is corresponding to `CodeGenerator.boxedType` + private def getClassFromDataType(dataType: DataType): Class[_] = dataType match { + case BooleanType => classOf[java.lang.Boolean] + case ByteType => classOf[java.lang.Byte] + case ShortType => classOf[java.lang.Short] + case IntegerType | DateType => classOf[java.lang.Integer] + case LongType | TimestampType => classOf[java.lang.Long] + case FloatType => classOf[java.lang.Float] + case DoubleType => classOf[java.lang.Double] + case _: DecimalType => classOf[Decimal] + case BinaryType => classOf[Array[Byte]] + case StringType => classOf[UTF8String] + case CalendarIntervalType => classOf[CalendarInterval] + case _: StructType => classOf[InternalRow] + case _: ArrayType => classOf[ArrayType] + case _: MapType => classOf[MapType] + case udt: UserDefinedType[_] => getClassFromDataType(udt.sqlType) + case ObjectType(cls) => cls + case _ => classOf[Object] + } + private lazy val checkType: (Any) => Boolean = expected match { case _: DecimalType => (value: Any) => { @@ -1685,7 +1707,7 @@ case class ValidateExternalType(child: Expression, expected: DataType) value.getClass.isArray || value.isInstanceOf[Seq[_]] } case _ => - val dataTypeClazz = RowEncoder.getClassFromExternalType(dataType) + val dataTypeClazz = getClassFromDataType(dataType) (value: Any) => { dataTypeClazz.isInstance(value) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index ca6c4c3a285f..7136af893448 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -274,7 +274,6 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(createExternalRow, Row.fromSeq(Seq(1, "x")), InternalRow.fromSeq(Seq())) } - // This is an alternative version of `checkEvaluation` to compare results // by scala values instead of catalyst values. private def checkObjectExprEvaluation( expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { From bede10199ca6aef24c27f5222e0699d055a18069 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 10 Apr 2018 12:29:55 +0900 Subject: [PATCH 6/7] Use ScalaReflection.typeBoxedJavaMapping --- .../catalyst/expressions/objects/objects.scala | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index bf25b0ad567b..9a4adaaf293b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1676,14 +1676,7 @@ case class ValidateExternalType(child: Expression, expected: DataType) private val errMsg = s" is not a valid external type for schema of ${expected.simpleString}" // This function is corresponding to `CodeGenerator.boxedType` - private def getClassFromDataType(dataType: DataType): Class[_] = dataType match { - case BooleanType => classOf[java.lang.Boolean] - case ByteType => classOf[java.lang.Byte] - case ShortType => classOf[java.lang.Short] - case IntegerType | DateType => classOf[java.lang.Integer] - case LongType | TimestampType => classOf[java.lang.Long] - case FloatType => classOf[java.lang.Float] - case DoubleType => classOf[java.lang.Double] + private def boxedType(dt: DataType): Class[_] = dataType match { case _: DecimalType => classOf[Decimal] case BinaryType => classOf[Array[Byte]] case StringType => classOf[UTF8String] @@ -1691,9 +1684,9 @@ case class ValidateExternalType(child: Expression, expected: DataType) case _: StructType => classOf[InternalRow] case _: ArrayType => classOf[ArrayType] case _: MapType => classOf[MapType] - case udt: UserDefinedType[_] => getClassFromDataType(udt.sqlType) + case udt: UserDefinedType[_] => boxedType(udt.sqlType) case ObjectType(cls) => cls - case _ => classOf[Object] + case _ => ScalaReflection.typeBoxedJavaMapping.getOrElse(dt, classOf[java.lang.Object]) } private lazy val checkType: (Any) => Boolean = expected match { @@ -1707,7 +1700,7 @@ case class ValidateExternalType(child: Expression, expected: DataType) value.getClass.isArray || value.isInstanceOf[Seq[_]] } case _ => - val dataTypeClazz = getClassFromDataType(dataType) + val dataTypeClazz = boxedType(dataType) (value: Any) => { dataTypeClazz.isInstance(value) } From 747481117136b08c37b8458b19e9157c931d6bb3 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 20 Apr 2018 07:43:54 +0900 Subject: [PATCH 7/7] Fix --- .../spark/sql/catalyst/ScalaReflection.scala | 13 +++++++++++++ .../catalyst/expressions/objects/objects.scala | 16 +--------------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 818cc2fb1e8a..f9acc208b715 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -846,6 +846,19 @@ object ScalaReflection extends ScalaReflection { } } + def javaBoxedType(dt: DataType): Class[_] = dt match { + case _: DecimalType => classOf[Decimal] + case BinaryType => classOf[Array[Byte]] + case StringType => classOf[UTF8String] + case CalendarIntervalType => classOf[CalendarInterval] + case _: StructType => classOf[InternalRow] + case _: ArrayType => classOf[ArrayType] + case _: MapType => classOf[MapType] + case udt: UserDefinedType[_] => javaBoxedType(udt.sqlType) + case ObjectType(cls) => cls + case _ => ScalaReflection.typeBoxedJavaMapping.getOrElse(dt, classOf[java.lang.Object]) + } + def expressionJavaClasses(arguments: Seq[Expression]): Seq[Class[_]] = { if (arguments != Nil) { arguments.map(e => dataTypeJavaClass(e.dataType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 9a4adaaf293b..5b4800bc6fa5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1675,20 +1675,6 @@ case class ValidateExternalType(child: Expression, expected: DataType) private val errMsg = s" is not a valid external type for schema of ${expected.simpleString}" - // This function is corresponding to `CodeGenerator.boxedType` - private def boxedType(dt: DataType): Class[_] = dataType match { - case _: DecimalType => classOf[Decimal] - case BinaryType => classOf[Array[Byte]] - case StringType => classOf[UTF8String] - case CalendarIntervalType => classOf[CalendarInterval] - case _: StructType => classOf[InternalRow] - case _: ArrayType => classOf[ArrayType] - case _: MapType => classOf[MapType] - case udt: UserDefinedType[_] => boxedType(udt.sqlType) - case ObjectType(cls) => cls - case _ => ScalaReflection.typeBoxedJavaMapping.getOrElse(dt, classOf[java.lang.Object]) - } - private lazy val checkType: (Any) => Boolean = expected match { case _: DecimalType => (value: Any) => { @@ -1700,7 +1686,7 @@ case class ValidateExternalType(child: Expression, expected: DataType) value.getClass.isArray || value.isInstanceOf[Seq[_]] } case _ => - val dataTypeClazz = boxedType(dataType) + val dataTypeClazz = ScalaReflection.javaBoxedType(dataType) (value: Any) => { dataTypeClazz.isInstance(value) }