Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ object Encoders {
*/
def STRING: Encoder[java.lang.String] = ExpressionEncoder()

/**
 * Creates an encoder for type `T` from its runtime class by delegating to
 * `ExpressionEncoder`.
 *
 * @param beanCls the class object describing `T`
 * @return an encoder for `T`
 */
def bean[T](beanCls: Class[T]): Encoder[T] = {
  ExpressionEncoder(beanCls)
}

/**
* (Scala-specific) Creates an encoder that serializes objects of type T using Kryo.
* This encoder maps T into a single byte array (binary) field.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,20 @@

package org.apache.spark.sql.catalyst

import java.beans.Introspector
import java.beans.{PropertyDescriptor, Introspector}
import java.lang.{Iterable => JIterable}
import java.util.{Iterator => JIterator, Map => JMap}
import java.util.{Iterator => JIterator, Map => JMap, List => JList}

import scala.language.existentials

import com.google.common.reflect.TypeToken

import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils}
import org.apache.spark.unsafe.types.UTF8String


/**
* Type-inference utilities for POJOs and Java collections.
Expand All @@ -33,13 +39,14 @@ object JavaTypeInference {

private val iterableType = TypeToken.of(classOf[JIterable[_]])
private val mapType = TypeToken.of(classOf[JMap[_, _]])
private val listType = TypeToken.of(classOf[JList[_]])
private val iteratorReturnType = classOf[JIterable[_]].getMethod("iterator").getGenericReturnType
private val nextReturnType = classOf[JIterator[_]].getMethod("next").getGenericReturnType
private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType
private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType

/**
* Infers the corresponding SQL data type of a JavaClean class.
* Infers the corresponding SQL data type of a JavaBean class.
* @param beanClass Java type
* @return (SQL data type, nullable)
*/
Expand All @@ -58,6 +65,8 @@ object JavaTypeInference {
(c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)

case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
case c: Class[_] if c == classOf[Array[Byte]] => (BinaryType, true)

case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
Expand Down Expand Up @@ -87,31 +96,293 @@ object JavaTypeInference {
(ArrayType(dataType, nullable), true)

case _ if mapType.isAssignableFrom(typeToken) =>
val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]]
val mapSupertype = typeToken2.getSupertype(classOf[JMap[_, _]])
val keyType = elementType(mapSupertype.resolveType(keySetReturnType))
val valueType = elementType(mapSupertype.resolveType(valuesReturnType))
val (keyType, valueType) = mapKeyValueType(typeToken)
val (keyDataType, _) = inferDataType(keyType)
val (valueDataType, nullable) = inferDataType(valueType)
(MapType(keyDataType, valueDataType, nullable), true)

case _ =>
val beanInfo = Introspector.getBeanInfo(typeToken.getRawType)
val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
val fields = properties.map { property =>
val returnType = typeToken.method(property.getReadMethod).getReturnType
val (dataType, nullable) = inferDataType(returnType)
new StructField(property.getName, dataType, nullable)
case other =>
val properties = getJavaBeanProperties(other)
if (properties.length > 0) {
val fields = properties.map { property =>
val returnType = typeToken.method(property.getReadMethod).getReturnType
val (dataType, nullable) = inferDataType(returnType)
new StructField(property.getName, dataType, nullable)
}
(new StructType(fields), true)
} else {
throw new UnsupportedOperationException(s"Cannot infer data type for ${other.getName}")
}
(new StructType(fields), true)
}
}

/**
 * Lists the bean properties of `beanClass` that can be both read and written.
 *
 * Any property descriptor lacking a read method or a write method is dropped.
 *
 * @param beanClass class to introspect
 * @return descriptors of the properties that have both a getter and a setter
 */
private def getJavaBeanProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
  val descriptors = Introspector.getBeanInfo(beanClass).getPropertyDescriptors
  descriptors.filterNot { descriptor =>
    descriptor.getReadMethod == null || descriptor.getWriteMethod == null
  }
}

/**
 * Resolves the element type of an `Iterable` type.
 *
 * The token is viewed as its `java.lang.Iterable` supertype, then the generic
 * return type of `iterator()` is resolved against it, and finally the generic
 * return type of `Iterator.next()` is resolved against that iterator type.
 *
 * @param typeToken a type token that is cast to a subtype of `java.lang.Iterable`
 * @return the type token of the items produced by iterating `typeToken`
 */
private def elementType(typeToken: TypeToken[_]): TypeToken[_] = {
  val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]]
  val iterableSuperType = typeToken2.getSupertype(classOf[JIterable[_]])
  val iteratorType = iterableSuperType.resolveType(iteratorReturnType)
  iteratorType.resolveType(nextReturnType)
}

/**
 * Resolves the key and value types of a `java.util.Map` type.
 *
 * The token is viewed as its `java.util.Map` supertype; the generic return
 * types of `keySet()` and `values()` are resolved against it and their element
 * types are taken as the key and value types respectively.
 *
 * @param typeToken a type token that is cast to a subtype of `java.util.Map`
 * @return a pair of (key type, value type)
 */
private def mapKeyValueType(typeToken: TypeToken[_]): (TypeToken[_], TypeToken[_]) = {
  val asMapToken = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]]
  val resolvedMap = asMapToken.getSupertype(classOf[JMap[_, _]])
  val keySetType = resolvedMap.resolveType(keySetReturnType)
  val keyElementType = elementType(keySetType)
  val valuesType = resolvedMap.resolveType(valuesReturnType)
  val valueElementType = elementType(valuesType)
  (keyElementType, valueElementType)
}

/**
 * Maps a Java class to the data type used for its external representation.
 *
 * The Java primitive classes handled below and `byte[]` map to their matching
 * SQL types; every other class is wrapped in an `ObjectType`.
 *
 * @param cls the Java class to classify
 * @return the matching primitive/binary type, or `ObjectType(cls)`
 */
private def inferExternalType(cls: Class[_]): DataType = {
  if (cls == java.lang.Boolean.TYPE) BooleanType
  else if (cls == java.lang.Byte.TYPE) ByteType
  else if (cls == java.lang.Short.TYPE) ShortType
  else if (cls == java.lang.Integer.TYPE) IntegerType
  else if (cls == java.lang.Long.TYPE) LongType
  else if (cls == java.lang.Float.TYPE) FloatType
  else if (cls == java.lang.Double.TYPE) DoubleType
  else if (cls == classOf[Array[Byte]]) BinaryType
  else ObjectType(cls)
}

/**
 * Entry point: builds the constructor expression for `beanClass`, starting
 * with no enclosing path.
 *
 * @param beanClass the class to build a constructor expression for
 * @return the expression produced by the path-aware overload with an empty path
 */
def constructorFor(beanClass: Class[_]): Expression = {
  val rootType: TypeToken[_] = TypeToken.of(beanClass)
  constructorFor(rootType, None)
}

/**
 * Builds the expression that turns the input value located at `path` into an
 * object of the Java type described by `typeToken`.
 *
 * The cases below are tried in order: non-object external types, boxed
 * primitives, date/timestamp, string, big decimal, arrays, lists, maps, and
 * finally any other class, which is handled through its bean properties.
 *
 * @param typeToken the Java type to construct
 * @param path expression yielding the input value; when empty, sub-fields are
 *             looked up as unresolved attributes and the input itself is read
 *             from a `BoundReference` at ordinal 0
 * @return the expression constructing the object
 */
private def constructorFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = {
/** Returns the current path with a sub-field extracted. */
def addToPath(part: String): Expression = path
.map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
.getOrElse(UnresolvedAttribute(part))

/** Returns the current path or `BoundReference`. */
def getPath: Expression = path.getOrElse(BoundReference(0, inferDataType(typeToken)._1, true))

typeToken.getRawType match {
// Classes whose external type is not an ObjectType (see inferExternalType)
// need no conversion: the value at the path is used directly.
case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath

// Boxed primitives: a NewInstance of the box class with the path value as
// its only argument (presumably compiled to a constructor call -- confirm
// against NewInstance).
case c if c == classOf[java.lang.Short] =>
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
case c if c == classOf[java.lang.Integer] =>
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
case c if c == classOf[java.lang.Long] =>
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
case c if c == classOf[java.lang.Double] =>
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
case c if c == classOf[java.lang.Byte] =>
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
case c if c == classOf[java.lang.Float] =>
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
case c if c == classOf[java.lang.Boolean] =>
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))

// java.sql.Date / java.sql.Timestamp go through the DateTimeUtils helpers.
case c if c == classOf[java.sql.Date] =>
StaticInvoke(
DateTimeUtils,
ObjectType(c),
"toJavaDate",
getPath :: Nil,
propagateNull = true)

case c if c == classOf[java.sql.Timestamp] =>
StaticInvoke(
DateTimeUtils,
ObjectType(c),
"toJavaTimestamp",
getPath :: Nil,
propagateNull = true)

case c if c == classOf[java.lang.String] =>
Invoke(getPath, "toString", ObjectType(classOf[String]))

case c if c == classOf[java.math.BigDecimal] =>
Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))

case c if c.isArray =>
val elementType = c.getComponentType
// Arrays of these primitive component types are produced by a single
// dedicated method call on the input value.
val primitiveMethod = elementType match {
case c if c == java.lang.Boolean.TYPE => Some("toBooleanArray")
case c if c == java.lang.Byte.TYPE => Some("toByteArray")
case c if c == java.lang.Short.TYPE => Some("toShortArray")
case c if c == java.lang.Integer.TYPE => Some("toIntArray")
case c if c == java.lang.Long.TYPE => Some("toLongArray")
case c if c == java.lang.Float.TYPE => Some("toFloatArray")
case c if c == java.lang.Double.TYPE => Some("toDoubleArray")
case _ => None
}

// Otherwise every element is converted recursively via MapObjects and
// the result's "array" method is invoked to obtain the array.
primitiveMethod.map { method =>
Invoke(getPath, method, ObjectType(c))
}.getOrElse {
Invoke(
MapObjects(
p => constructorFor(typeToken.getComponentType, Some(p)),
getPath,
inferDataType(elementType)._1),
"array",
ObjectType(c))
}

// java.util.List: convert each element, take the resulting array, and wrap
// it with java.util.Arrays.asList.
case c if listType.isAssignableFrom(typeToken) =>
val et = elementType(typeToken)
val array =
Invoke(
MapObjects(
p => constructorFor(et, Some(p)),
getPath,
inferDataType(et)._1),
"array",
ObjectType(classOf[Array[Any]]))

StaticInvoke(classOf[java.util.Arrays], ObjectType(c), "asList", array :: Nil)

// java.util.Map: keys and values are converted separately from the input's
// "keyArray" and "valueArray", then combined by ArrayBasedMapData.toJavaMap.
case _ if mapType.isAssignableFrom(typeToken) =>
val (keyType, valueType) = mapKeyValueType(typeToken)
val keyDataType = inferDataType(keyType)._1
val valueDataType = inferDataType(valueType)._1

val keyData =
Invoke(
MapObjects(
p => constructorFor(keyType, Some(p)),
Invoke(getPath, "keyArray", ArrayType(keyDataType)),
keyDataType),
"array",
ObjectType(classOf[Array[Any]]))

val valueData =
Invoke(
MapObjects(
p => constructorFor(valueType, Some(p)),
Invoke(getPath, "valueArray", ArrayType(valueDataType)),
valueDataType),
"array",
ObjectType(classOf[Array[Any]]))

StaticInvoke(
ArrayBasedMapData,
ObjectType(classOf[JMap[_, _]]),
"toJavaMap",
keyData :: valueData :: Nil)

// Any other class is treated as a bean with at least one readable and
// writable property.
case other =>
val properties = getJavaBeanProperties(other)
assert(properties.length > 0)

// Setter method name -> expression constructing that property's value,
// read from the sub-field named after the property.
val setters = properties.map { p =>
val fieldName = p.getName
val fieldType = typeToken.method(p.getReadMethod).getReturnType
p.getWriteMethod.getName -> constructorFor(fieldType, Some(addToPath(fieldName)))
}.toMap

val newInstance = NewInstance(other, Nil, propagateNull = false, ObjectType(other))
val result = InitializeJavaBean(newInstance, setters)

// Only when a path is present (i.e. not the top-level call) is the bean
// guarded by a null check that yields a null object literal.
if (path.nonEmpty) {
expressions.If(
IsNull(getPath),
expressions.Literal.create(null, ObjectType(other)),
result
)
} else {
result
}
}
}

/**
 * Entry point: builds the extractor expression for `beanClass` applied to
 * `inputObject` and returns it as a `CreateNamedStruct`.
 *
 * The result of the type-token based `extractorFor` is cast, so this fails
 * with a `ClassCastException` if that expression is not a `CreateNamedStruct`.
 *
 * @param inputObject expression yielding the object to extract from
 * @param beanClass the class of the object
 * @return the struct-creating extractor expression
 */
def extractorsFor(inputObject: Expression, beanClass: Class[_]): CreateNamedStruct = {
  val beanType: TypeToken[_] = TypeToken.of(beanClass)
  val extractor = extractorFor(inputObject, beanType)
  extractor.asInstanceOf[CreateNamedStruct]
}

private def extractorFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {

/**
 * Builds the expression converting a collection-like `input` whose elements
 * have type `elementType`.
 *
 * When the inferred element data type is reported as native by
 * `ScalaReflection.isNativeType`, the input is wrapped in a `GenericArrayData`
 * instance; otherwise each element is converted through `extractorFor`.
 */
def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = {
  val elementExternalType = inferExternalType(elementType.getRawType)
  val (elementDataType, containsNull) = inferDataType(elementType)
  if (!ScalaReflection.isNativeType(elementDataType)) {
    MapObjects(element => extractorFor(element, elementType), input, elementExternalType)
  } else {
    NewInstance(
      classOf[GenericArrayData],
      input :: Nil,
      dataType = ArrayType(elementDataType, containsNull))
  }
}

if (!inputObject.dataType.isInstanceOf[ObjectType]) {
inputObject
} else {
typeToken.getRawType match {
case c if c == classOf[String] =>
StaticInvoke(
classOf[UTF8String],
StringType,
"fromString",
inputObject :: Nil)

case c if c == classOf[java.sql.Timestamp] =>
StaticInvoke(
DateTimeUtils,
TimestampType,
"fromJavaTimestamp",
inputObject :: Nil)

case c if c == classOf[java.sql.Date] =>
StaticInvoke(
DateTimeUtils,
DateType,
"fromJavaDate",
inputObject :: Nil)

case c if c == classOf[java.math.BigDecimal] =>
StaticInvoke(
Decimal,
DecimalType.SYSTEM_DEFAULT,
"apply",
inputObject :: Nil)

case c if c == classOf[java.lang.Boolean] =>
Invoke(inputObject, "booleanValue", BooleanType)
case c if c == classOf[java.lang.Byte] =>
Invoke(inputObject, "byteValue", ByteType)
case c if c == classOf[java.lang.Short] =>
Invoke(inputObject, "shortValue", ShortType)
case c if c == classOf[java.lang.Integer] =>
Invoke(inputObject, "intValue", IntegerType)
case c if c == classOf[java.lang.Long] =>
Invoke(inputObject, "longValue", LongType)
case c if c == classOf[java.lang.Float] =>
Invoke(inputObject, "floatValue", FloatType)
case c if c == classOf[java.lang.Double] =>
Invoke(inputObject, "doubleValue", DoubleType)

case _ if typeToken.isArray =>
toCatalystArray(inputObject, typeToken.getComponentType)

case _ if listType.isAssignableFrom(typeToken) =>
toCatalystArray(inputObject, elementType(typeToken))

case _ if mapType.isAssignableFrom(typeToken) =>
throw new UnsupportedOperationException("map type is not supported currently")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is that, for a Java map, if we get the keys and values via `keySet` and `values`, we cannot guarantee that they have the same iteration order (which is different from a Scala map). A possible solution is to create a new `MapObjects` that can iterate over a map directly.

cc @marmbrus


case other =>
val properties = getJavaBeanProperties(other)
assert(properties.length > 0)

CreateNamedStruct(properties.flatMap { p =>
val fieldName = p.getName
val fieldType = typeToken.method(p.getReadMethod).getReturnType
val fieldValue = Invoke(
inputObject,
p.getReadMethod.getName,
inferExternalType(fieldType.getRawType))
expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil
})
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@
package org.apache.spark.sql.catalyst

import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, ArrayData, DateTimeUtils}
import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
Expand Down
Loading