Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix
  • Loading branch information
maropu committed Apr 19, 2018
commit 8cec38288f1ba0a48129693e9da9573448cf91b1
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import org.apache.spark.unsafe.types.UTF8String

/**
* A factory for constructing encoders that convert external row to/from the Spark SQL
Expand Down Expand Up @@ -235,26 +235,6 @@ object RowEncoder {
case udt: UserDefinedType[_] => ObjectType(udt.userClass)
}

// Maps an external-input type (as produced by `externalDataTypeForInput`) to the JVM
// runtime class its values must carry. `PythonUserDefinedType` and `UserDefinedType`
// never reach this method: `externalDataTypeFor` rewrites them into native types or
// `ObjectType`s beforehand, so the cases below cover every possible input.
def getClassFromExternalType(externalType: DataType): Class[_] = externalType match {
  // Composite external types (e.g., array, map, and struct) always arrive as `ObjectType`.
  case ObjectType(cls) => cls
  case BinaryType => classOf[Array[Byte]]
  case CalendarIntervalType => classOf[CalendarInterval]
  case NullType => classOf[Object]
  // Primitive SQL types map to their boxed JVM counterparts.
  case BooleanType => classOf[java.lang.Boolean]
  case ByteType => classOf[java.lang.Byte]
  case ShortType => classOf[java.lang.Short]
  case IntegerType => classOf[java.lang.Integer]
  case LongType => classOf[java.lang.Long]
  case FloatType => classOf[java.lang.Float]
  case DoubleType => classOf[java.lang.Double]
}

private def deserializerFor(schema: StructType): Expression = {
val fields = schema.zipWithIndex.map { case (f, i) =>
val dt = f.dataType match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import org.apache.spark.util.Utils

/**
Expand Down Expand Up @@ -1674,6 +1675,27 @@ case class ValidateExternalType(child: Expression, expected: DataType)

private val errMsg = s" is not a valid external type for schema of ${expected.simpleString}"

// This function is corresponding to `CodeGenerator.boxedType`
private def getClassFromDataType(dataType: DataType): Class[_] = dataType match {
case BooleanType => classOf[java.lang.Boolean]
case ByteType => classOf[java.lang.Byte]
case ShortType => classOf[java.lang.Short]
case IntegerType | DateType => classOf[java.lang.Integer]
case LongType | TimestampType => classOf[java.lang.Long]
case FloatType => classOf[java.lang.Float]
case DoubleType => classOf[java.lang.Double]
case _: DecimalType => classOf[Decimal]
case BinaryType => classOf[Array[Byte]]
case StringType => classOf[UTF8String]
case CalendarIntervalType => classOf[CalendarInterval]
case _: StructType => classOf[InternalRow]
case _: ArrayType => classOf[ArrayType]
case _: MapType => classOf[MapType]
case udt: UserDefinedType[_] => getClassFromDataType(udt.sqlType)
case ObjectType(cls) => cls
case _ => classOf[Object]
}

private lazy val checkType: (Any) => Boolean = expected match {
case _: DecimalType =>
(value: Any) => {
Expand All @@ -1685,7 +1707,7 @@ case class ValidateExternalType(child: Expression, expected: DataType)
value.getClass.isArray || value.isInstanceOf[Seq[_]]
}
case _ =>
val dataTypeClazz = RowEncoder.getClassFromExternalType(dataType)
val dataTypeClazz = getClassFromDataType(dataType)
(value: Any) => {
dataTypeClazz.isInstance(value)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,6 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(createExternalRow, Row.fromSeq(Seq(1, "x")), InternalRow.fromSeq(Seq()))
}

// This is an alternative version of `checkEvaluation` to compare results
// by scala values instead of catalyst values.
private def checkObjectExprEvaluation(
expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = {
Expand Down