address comments
cloud-fan committed Dec 1, 2015
commit 4abac6865cc3f6667c6053a384661e2674a26407
11 changes: 10 additions & 1 deletion sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -102,9 +102,18 @@ object Encoders {
*
* T must be publicly accessible.
*
* Supported types for Java bean fields:
* - primitive types: boolean, int, double, etc.
* - boxed types: Boolean, Integer, Double, etc.
* - String
* - java.math.BigDecimal
* - time-related types: java.sql.Date, java.sql.Timestamp
* - collection types: currently only arrays and java.util.List; map support is in progress
* - nested Java beans
*
* @since 1.6.0
*/
def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder(beanClass)
def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder.javaBean(beanClass)
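For orientation, a minimal sketch of defining a compliant bean and obtaining its encoder from Scala; the Person class and its fields are hypothetical, and @BeanProperty merely stands in for hand-written getters and setters:

import scala.beans.BeanProperty
import org.apache.spark.sql.{Encoder, Encoders}

// Hypothetical bean whose fields all come from the supported list above.
class Person extends Serializable {
  @BeanProperty var name: String = _                  // String
  @BeanProperty var age: Int = _                      // primitive int
  @BeanProperty var tags: java.util.List[String] = _  // java.util.List
}

val personEncoder: Encoder[Person] = Encoders.bean(classOf[Person])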

/**
* (Scala-specific) Creates an encoder that serializes objects of type T using Kryo.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -382,6 +382,9 @@ object JavaTypeInference {
toCatalystArray(inputObject, elementType(typeToken))

case _ if mapType.isAssignableFrom(typeToken) =>
// TODO: for a Java map, if we get the keys and values via `keySet` and `values`, we
// cannot guarantee they have the same iteration order (unlike a Scala map).
// A possible solution is creating a new `MapObjects` that can iterate a map directly.
throw new UnsupportedOperationException("map type is not supported currently")
cloud-fan (Contributor, Author) commented:
The problem is that, for a Java map, if we get the keys and values via keySet and values, we cannot guarantee they have the same iteration order (unlike a Scala map). A possible solution is creating a new MapObjects that can iterate a map directly.

cc @marmbrus
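An entrySet-based traversal along the lines the TODO suggests could look like this sketch (helper name hypothetical); reading entries as pairs sidesteps the keySet/values ordering question entirely:

import scala.collection.mutable.ArrayBuffer

// Sketch: walk a java.util.Map via entrySet() so each key stays paired with
// its value, instead of iterating keySet() and values() independently.
def splitEntries[K, V](m: java.util.Map[K, V]): (Seq[K], Seq[V]) = {
  val keys = ArrayBuffer.empty[K]
  val values = ArrayBuffer.empty[V]
  val it = m.entrySet().iterator()
  while (it.hasNext) {
    val e = it.next()
    keys += e.getKey
    values += e.getValue
  }
  (keys.toSeq, values.toSeq)
}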


case other =>
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -67,7 +67,7 @@ object ExpressionEncoder {
}

// TODO: improve error message for java bean encoder.
def apply[T](beanClass: Class[T]): ExpressionEncoder[T] = {
def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = {
val schema = JavaTypeInference.inferDataType(beanClass)._1
assert(schema.isInstanceOf[StructType])

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -603,34 +603,34 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
/**
* Initialize a Java Bean instance by setting its field values via setters.
*/
case class InitializeJavaBean(n: NewInstance, setters: Map[String, Expression])
case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Expression])
extends Expression {

override def nullable: Boolean = n.nullable
override def children: Seq[Expression] = n +: setters.values.toSeq
override def dataType: DataType = n.dataType
override def nullable: Boolean = beanInstance.nullable
override def children: Seq[Expression] = beanInstance +: setters.values.toSeq
override def dataType: DataType = beanInstance.dataType

override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val instance = n.gen(ctx)
val instanceGen = beanInstance.gen(ctx)

val initialize = setters.map {
case (setterMethod, fieldValue) =>
val fieldGen = fieldValue.gen(ctx)
s"""
${fieldGen.code}
${instance.value}.$setterMethod(${fieldGen.value});
${instanceGen.value}.$setterMethod(${fieldGen.value});
"""
}

ev.isNull = instance.isNull
ev.value = instance.value
ev.isNull = instanceGen.isNull
ev.value = instanceGen.value

s"""
${instance.code}
if (!${instance.isNull}) {
${instanceGen.code}
if (!${instanceGen.isNull}) {
${initialize.mkString("\n")}
}
"""
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -205,7 +205,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
}
case nonChild: AnyRef => nonChild
case null => null
}.view.force
}.view.force // `mapValues` is lazy and we need to force it to materialize
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = remainingNewChildren.remove(0)
val oldChild = remainingOldChildren.remove(0)
@@ -287,7 +287,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
arg
}
case other => other
}.view.force
}.view.force // `mapValues` is lazy and we need to force it to materialize
case d: DataType => d // Avoid unpacking Structs
case args: Traversable[_] => args.map {
case arg: TreeNode[_] if containsChild(arg) =>
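The new comment deserves a quick demonstration. On Scala 2.12 and earlier (a version assumption; Spark used 2.10/2.11 at the time), mapValues returns a lazy view that re-runs its function on every access, so without .view.force the transformed children would be recomputed, producing fresh nodes, on every read:

var evaluations = 0
val lazyMap = Map("a" -> 1).mapValues { v => evaluations += 1; v + 1 }
assert(evaluations == 0)   // nothing computed yet: mapValues is just a view
lazyMap("a"); lazyMap("a")
assert(evaluations == 2)   // the function re-runs on every lookup
val strictMap = lazyMap.view.force
assert(evaluations == 3)   // forced exactly once per entry...
strictMap("a"); strictMap("a")
assert(evaluations == 3)   // ...and never again after materialization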
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -38,6 +38,13 @@ case class ComplexPlan(exprs: Seq[Seq[Expression]])
override def output: Seq[Attribute] = Nil
}

case class ExpressionInMap(map: Map[String, Expression]) extends Expression with Unevaluable {
override def children: Seq[Expression] = map.values.toSeq
override def nullable: Boolean = true
override def dataType: NullType = NullType
override lazy val resolved = true
}

class TreeNodeSuite extends SparkFunSuite {
test("top node changed") {
val after = Literal(1) transform { case Literal(1, _) => Literal(2) }
@@ -236,4 +243,22 @@ class TreeNodeSuite extends SparkFunSuite {
val expected = ComplexPlan(Seq(Seq(Literal("1")), Seq(Literal("2"))))
assert(expected === actual)
}

test("expressions inside a map") {
val expression = ExpressionInMap(Map("1" -> Literal(1), "2" -> Literal(2)))

{
val actual = expression.transform {
case Literal(i: Int, _) => Literal(i + 1)
}
val expected = ExpressionInMap(Map("1" -> Literal(2), "2" -> Literal(3)))
assert(actual === expected)
}

{
val actual = expression.withNewChildren(Seq(Literal(2), Literal(3)))
val expected = ExpressionInMap(Map("1" -> Literal(2), "2" -> Literal(3)))
assert(actual === expected)
}
}
}
sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -31,15 +31,15 @@
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.function.*;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.GroupedDataset;
import org.apache.spark.sql.*;
import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.test.TestSQLContext;
import org.apache.spark.sql.catalyst.encoders.OuterScopes;
import org.apache.spark.sql.catalyst.expressions.GenericRow;
import org.apache.spark.sql.types.StructType;

import static org.apache.spark.sql.functions.*;
import static org.apache.spark.sql.types.DataTypes.*;

public class JavaDatasetSuite implements Serializable {
private transient JavaSparkContext jsc;
@@ -613,14 +613,14 @@ public void testJavaBeanEncoder() {
SimpleJavaBean obj1 = new SimpleJavaBean();
obj1.setA(true);
obj1.setB(3);
obj1.setC(new byte[]{1});
obj1.setC(new byte[]{1, 2});
obj1.setD(new String[]{"hello"});
obj1.setE(Arrays.asList("a", "b"));
SimpleJavaBean obj2 = new SimpleJavaBean();
obj2.setA(false);
obj2.setB(30);
obj2.setC(new byte[]{2});
obj1.setD(new String[]{"world"});
obj2.setC(new byte[]{3, 4});
obj2.setD(new String[]{"world"});
obj2.setE(Arrays.asList("x", "y"));

List<SimpleJavaBean> data = Arrays.asList(obj1, obj2);
@@ -633,5 +633,27 @@ public void testJavaBeanEncoder() {
List<NestedJavaBean> data2 = Arrays.asList(obj3);
Dataset<NestedJavaBean> ds2 = context.createDataset(data2, Encoders.bean(NestedJavaBean.class));
Assert.assertEquals(data2, ds2.collectAsList());

Row row1 = new GenericRow(new Object[]{
true,
3,
new byte[]{1, 2},
new String[]{"hello"},
Arrays.asList("a", "b")});
Row row2 = new GenericRow(new Object[]{
false,
30,
new byte[]{3, 4},
new String[]{"world"},
Arrays.asList("x", "y")});
StructType schema = new StructType()
.add("a", BooleanType, false)
.add("b", IntegerType, false)
.add("c", BinaryType)
.add("d", createArrayType(StringType))
.add("e", createArrayType(StringType));
Dataset<SimpleJavaBean> ds3 = context.createDataFrame(Arrays.asList(row1, row2), schema)
.as(Encoders.bean(SimpleJavaBean.class));
Assert.assertEquals(data, ds3.collectAsList());
}
}