From 7504c6790bf9bad143bce9f259e1ce98a5b40043 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 11 Nov 2015 21:43:07 +0800 Subject: [PATCH 1/4] add reduce to GroupedDataset --- .../org/apache/spark/sql/GroupedDataset.scala | 21 +++++++++++++------ .../apache/spark/sql/JavaDatasetSuite.java | 11 ++++++++++ .../org/apache/spark/sql/DatasetSuite.scala | 9 ++++++++ 3 files changed, 35 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index db6149922928..183e4d520d03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -17,12 +17,10 @@ package org.apache.spark.sql -import java.util.{Iterator => JIterator} - import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.function.{Function2 => JFunction2, Function3 => JFunction3, _} +import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute} import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder} import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute} @@ -127,15 +125,26 @@ class GroupedDataset[K, T] private[sql]( */ def map[U : Encoder](f: (K, Iterator[T]) => U): Dataset[U] = { val func = (key: K, it: Iterator[T]) => Iterator(f(key, it)) - new Dataset[U]( - sqlContext, - MapGroups(func, groupingAttributes, logicalPlan)) + flatMap(func) } def map[U](f: MapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = { map((key, data) => f.call(key, data.asJava))(encoder) } + /** + * Reduces the elements of each group of data using the specified binary function. + * The given function must be commutative and associative or the result may be non-deterministic. + */ + def reduce(f: (T, T) => T): Dataset[(K, T)] = { + val func = (key: K, it: Iterator[T]) => Iterator(key -> it.reduce(f)) + flatMap(func)(ExpressionEncoder.tuple(kEnc, tEnc)) + } + + def reduce(f: ReduceFunction[T]): Dataset[(K, T)] = { + reduce(f.call _) + } + // To ensure valid overloading. protected def agg(expr: Column, exprs: Column*): DataFrame = groupedData.agg(expr, exprs: _*) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 2da63d1b9670..8aadf41c76a7 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -197,6 +197,17 @@ public Iterable call(Integer key, Iterator values) throws Except Assert.assertEquals(Arrays.asList("1a", "3foobar"), flatMapped.collectAsList()); + Dataset> reduced = grouped.reduce(new ReduceFunction() { + @Override + public String call(String v1, String v2) throws Exception { + return v1 + v2; + } + }); + + Assert.assertEquals( + Arrays.asList(tuple2(1, "a"), tuple2(3, "foobar")), + reduced.collectAsList()); + List data2 = Arrays.asList(2, 6, 10); Dataset ds2 = context.createDataset(data2, e.INT()); GroupedDataset grouped2 = ds2.groupBy(new MapFunction() { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 621148528714..449bab1b23d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -218,6 +218,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext { "a", "30", "b", "3", "c", "1") } + test("groupBy function, reduce") { + val ds = Seq("abc", "xzy", "hello").toDS() + val agged = ds.groupBy(_.length).reduce(_ + _) + + checkAnswer( + agged, + 3 -> "abcxyz", 5 -> "hello") + } + test("groupBy columns, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1") From e58a0c91d90ee3780939a9a3e9fd37b11f7557a2 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 11 Nov 2015 18:18:11 -0800 Subject: [PATCH 2/4] Working finally --- .../catalyst/analysis/HiveTypeCoercion.scala | 1 + .../catalyst/encoders/ExpressionEncoder.scala | 62 ++++++++++++-- .../expressions/complexTypeExtractors.scala | 2 +- .../scala/org/apache/spark/sql/Column.scala | 24 +++++- .../org/apache/spark/sql/DataFrame.scala | 17 +--- .../scala/org/apache/spark/sql/Dataset.scala | 50 ++++------- .../org/apache/spark/sql/GroupedDataset.scala | 46 ++++------- .../spark/sql/execution/Queryable.scala | 21 +++++ .../aggregate/TypedAggregateExpression.scala | 3 +- .../apache/spark/sql/JavaDatasetSuite.java | 1 - .../spark/sql/DatasetAggregatorSuite.scala | 82 +++++++++---------- .../org/apache/spark/sql/DatasetSuite.scala | 2 +- .../org/apache/spark/sql/QueryTest.scala | 7 +- 13 files changed, 178 insertions(+), 140 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index bf2bff0243fa..2d8060086003 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -712,6 +712,7 @@ object HiveTypeCoercion { case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => + println(s"$in -> $expected") // If we cannot do the implicit cast, just use the original input. implicitCast(in, expected).getOrElse(in) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 294afde5347e..1e9adea33202 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.types.{StructField, ObjectType, StructType} +import org.apache.spark.sql.types.{NullType, StructField, ObjectType, StructType} /** * A factory for constructing encoders that convert objects and primitves to and from the @@ -67,14 +67,41 @@ object ExpressionEncoder { def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = { val schema = StructType( - encoders.zipWithIndex.map { case (e, i) => StructField(s"_${i + 1}", e.schema)}) + encoders.zipWithIndex.map { case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)}) val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") - val extractExpressions = encoders.map { - case e if e.flat => e.extractExpressions.head - case other => CreateStruct(other.extractExpressions) + + // Rebind the encoders to the nested schema that will be produced by the aggregation. + val newConstructExpressions = encoders.zipWithIndex.map { + case (e, i) if !e.flat => + println(s"=== $i - nested ===") + println(e.constructExpression.treeString) + println() + println(e.nested(i).constructExpression.treeString) + + e.nested(i).constructExpression + case (e, i) => + println(s"=== $i - flat ===") + println(e.constructExpression.treeString) + println() + println(e.shift(i).constructExpression.treeString) + + e.shift(i).constructExpression } + val constructExpression = - NewInstance(cls, encoders.map(_.constructExpression), false, ObjectType(cls)) + NewInstance(cls, newConstructExpressions, false, ObjectType(cls)) + + val input = BoundReference(0, ObjectType(cls), false) + val extractExpressions = encoders.zipWithIndex.map { + case (e, i) if !e.flat => CreateStruct(e.extractExpressions.map(_ transformUp { + case b: BoundReference => + Invoke(input, s"_${i + 1}", b.dataType, Nil) + })) + case (e, i) => e.extractExpressions.head transformUp { + case b: BoundReference => + Invoke(input, s"_${i + 1}", b.dataType, Nil) + } + } new ExpressionEncoder[Any]( schema, @@ -121,9 +148,12 @@ case class ExpressionEncoder[T]( * toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should * copy the result before making another call if required. */ - def toRow(t: T): InternalRow = { + def toRow(t: T): InternalRow = try { inputRow(0) = t extractProjection(inputRow) + } catch { + case e: Exception => + throw new RuntimeException(s"Error while encoding: $e\n${extractExpressions.map(_.treeString).mkString("\n")}", e) } /** @@ -191,6 +221,24 @@ case class ExpressionEncoder[T]( }) } + /** + * Returns a copy of this encoder where the expressions used to create an object given an + * input row have been modified to pull the object out from a nested struct, instead of the + * top level fields. + */ + def nested(i: Int): ExpressionEncoder[T] = { + val input = BoundReference(i, NullType, true) + copy(constructExpression = constructExpression transformUp { + case u: Attribute => + UnresolvedExtractValue(input, Literal(u.name)) + case b: BoundReference => + GetStructField( + input, + StructField(s"i[${b.ordinal}]", b.dataType), + b.ordinal) + }) + } + /** * Returns a copy of this encoder where the expressions used to create an object given an * input row have been modified to pull the object out from a nested struct, instead of the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 41cd0a104a1f..96652b65fa0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -101,7 +101,7 @@ object ExtractValue { case class GetStructField(child: Expression, field: StructField, ordinal: Int) extends UnaryExpression { - override def dataType: DataType = field.dataType + override def dataType: DataType = child.dataType.asInstanceOf[StructType](ordinal).dataType override def nullable: Boolean = child.nullable || field.nullable override def toString: String = s"$child.${field.name}" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index f0f275e91f1a..06d9061be0c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql +import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression + import scala.language.implicitConversions import org.apache.spark.annotation.Experimental import org.apache.spark.Logging import org.apache.spark.sql.functions.lit import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.encoders.encoderFor +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DataTypeParser import org.apache.spark.sql.types._ @@ -45,7 +47,25 @@ private[sql] object Column { * checked by the analyzer instead of the compiler (i.e. `expr("sum(...)")`). * @tparam U The output type of this column. */ -class TypedColumn[-T, U](expr: Expression, val encoder: Encoder[U]) extends Column(expr) +class TypedColumn[-T, U]( + expr: Expression, + private[sql] val encoder: ExpressionEncoder[U]) extends Column(expr) { + + /** + * Inserts the specific input type and schema into any expressions that are expected to operate + * on a decoded object. + */ + private[sql] def withInputType( + inputEncoder: ExpressionEncoder[_], + schema: Seq[Attribute]): TypedColumn[T, U] = { + new TypedColumn[T, U] (expr transform { + case ta: TypedAggregateExpression if ta.aEncoder.isEmpty => + ta.copy( + aEncoder = Some(inputEncoder.asInstanceOf[ExpressionEncoder[Any]]), + children = schema) + }, encoder) + } +} /** * :: Experimental :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index a492099b9392..553c10757cce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -735,22 +735,7 @@ class DataFrame private[sql]( */ @scala.annotation.varargs def select(cols: Column*): DataFrame = withPlan { - val namedExpressions = cols.map { - // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we - // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to - // make it a NamedExpression. - case Column(u: UnresolvedAttribute) => UnresolvedAlias(u) - - case Column(expr: NamedExpression) => expr - - // Leave an unaliased generator with an empty list of names since the analyzer will generate - // the correct defaults after the nested expression's type has been resolved. - case Column(explode: Explode) => MultiAlias(explode, Nil) - case Column(jt: JsonTuple) => MultiAlias(jt, Nil) - - case Column(expr: Expression) => Alias(expr, expr.prettyString)() - } - Project(namedExpressions.toSeq, logicalPlan) + Project(nameColumns(cols), logicalPlan) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 87dae6b33159..95bb4e4ad344 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -63,11 +63,12 @@ import org.apache.spark.sql.types.StructType class Dataset[T] private[sql]( @transient val sqlContext: SQLContext, @transient val queryExecution: QueryExecution, - unresolvedEncoder: Encoder[T]) extends Queryable with Serializable { + implicit val unresolvedEncoder: Encoder[T]) extends Queryable with Serializable { /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ - private[sql] implicit val encoder: ExpressionEncoder[T] = unresolvedEncoder match { - case e: ExpressionEncoder[T] => e.resolve(queryExecution.analyzed.output) + private[sql] val encoder: ExpressionEncoder[T] = unresolvedEncoder match { + case e: ExpressionEncoder[T] => + e.unbind(queryExecution.analyzed.output).resolve(queryExecution.analyzed.output) case _ => throw new IllegalArgumentException("Only expression encoders are currently supported") } @@ -134,7 +135,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def rdd: RDD[T] = { - val tEnc = encoderFor[T] + val tEnc = encoder val input = queryExecution.analyzed.output queryExecution.toRdd.mapPartitions { iter => val bound = tEnc.bind(input) @@ -195,7 +196,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { - new Dataset( + new Dataset[U]( sqlContext, MapPartitions[T, U]( func, @@ -360,7 +361,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = { - new Dataset[U1](sqlContext, Project(Alias(withEncoder(c1).expr, "_1")() :: Nil, logicalPlan)) + new Dataset[U1](sqlContext, Project(nameColumns(c1.withInputType(encoder, queryExecution.analyzed.output) :: Nil), logicalPlan)) } /** @@ -369,28 +370,11 @@ class Dataset[T] private[sql]( * that cast appropriately for the user facing interface. */ protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { - val withEncoders = columns.map(withEncoder) - val aliases = withEncoders.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() } - val unresolvedPlan = Project(aliases, logicalPlan) - val execution = new QueryExecution(sqlContext, unresolvedPlan) - // Rebind the encoders to the nested schema that will be produced by the select. - val encoders = withEncoders.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map { - case (e: ExpressionEncoder[_], a) if !e.flat => - e.nested(a.toAttribute).resolve(execution.analyzed.output) - case (e, a) => - e.unbind(a.toAttribute :: Nil).resolve(execution.analyzed.output) - } - new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) - } + val encoders = columns.map(_.encoder) + val namedColumns = nameColumns(columns.map(_.withInputType(encoder, queryExecution.analyzed.output))) + val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan)) - private def withEncoder(c: TypedColumn[_, _]): TypedColumn[_, _] = { - val e = c.expr transform { - case ta: TypedAggregateExpression if ta.aEncoder.isEmpty => - ta.copy( - aEncoder = Some(encoder.asInstanceOf[ExpressionEncoder[Any]]), - children = queryExecution.analyzed.output) - } - new TypedColumn(e, c.encoder) + new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) } /** @@ -505,15 +489,9 @@ class Dataset[T] private[sql]( case e if e.flat => Alias(right.output.head, "_2")() case _ => Alias(CreateStruct(right.output), "_2")() } - val leftEncoder = - if (encoder.flat) encoder else encoder.nested(leftData.toAttribute) - val rightEncoder = - if (other.encoder.flat) other.encoder else other.encoder.nested(rightData.toAttribute) - implicit val tuple2Encoder: Encoder[(T, U)] = - ExpressionEncoder.tuple( - leftEncoder, - rightEncoder.rebind(right.output, left.output ++ right.output)) + + implicit val tuple2Encoder: Encoder[(T, U)] = ExpressionEncoder.tuple(unresolvedEncoder.asInstanceOf[ExpressionEncoder[T]], other.unresolvedEncoder.asInstanceOf[ExpressionEncoder[U]]) withPlan[(T, U)](other) { (left, right) => Project( leftData :: rightData :: Nil, @@ -580,7 +558,7 @@ class Dataset[T] private[sql]( private[sql] def logicalPlan = queryExecution.analyzed private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] = - new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), encoder) + new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), unresolvedEncoder) private[sql] def withPlan[R : Encoder]( other: Dataset[_])( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 37f89176a2a4..a4ddb705ec57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import org.apache.spark.sql.types.StructType + import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental @@ -26,7 +28,7 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression -import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.{Queryable, QueryExecution} /** @@ -44,9 +46,11 @@ import org.apache.spark.sql.execution.QueryExecution class GroupedDataset[K, T] private[sql]( private val kEncoder: Encoder[K], private val tEncoder: Encoder[T], - queryExecution: QueryExecution, + val queryExecution: QueryExecution, private val dataAttributes: Seq[Attribute], - private val groupingAttributes: Seq[Attribute]) extends Serializable { + private val groupingAttributes: Seq[Attribute]) extends Queryable with Serializable { + + def schema: StructType = ??? private implicit val kEnc = kEncoder match { case e: ExpressionEncoder[K] => e.unbind(groupingAttributes).resolve(groupingAttributes) @@ -138,7 +142,9 @@ class GroupedDataset[K, T] private[sql]( */ def reduce(f: (T, T) => T): Dataset[(K, T)] = { val func = (key: K, it: Iterator[T]) => Iterator(key -> it.reduce(f)) - flatMap(func)(ExpressionEncoder.tuple(kEnc, tEnc)) + println(ExpressionEncoder.tuple(kEnc, tEnc).constructExpression.treeString) + + flatMap(func)(ExpressionEncoder.tuple(kEncoder.asInstanceOf[ExpressionEncoder[K]], tEncoder.asInstanceOf[ExpressionEncoder[T]])) } def reduce(f: ReduceFunction[T]): Dataset[(K, T)] = { @@ -156,37 +162,15 @@ class GroupedDataset[K, T] private[sql]( * TODO: does not handle aggrecations that return nonflat results, */ protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { - val aliases = (groupingAttributes ++ columns.map(_.expr)).map { - case u: UnresolvedAttribute => UnresolvedAlias(u) - case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.prettyString)() - } - - val unresolvedPlan = Aggregate(groupingAttributes, aliases, logicalPlan) - - // Fill in the input encoders for any aggregators in the plan. - val withEncoders = unresolvedPlan transformAllExpressions { - case ta: TypedAggregateExpression if ta.aEncoder.isEmpty => - ta.copy( - aEncoder = Some(tEnc.asInstanceOf[ExpressionEncoder[Any]]), - children = dataAttributes) - } - val execution = new QueryExecution(sqlContext, withEncoders) - - val columnEncoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]) - - // Rebind the encoders to the nested schema that will be produced by the aggregation. - val encoders = (kEnc +: columnEncoders).zip(execution.analyzed.output).map { - case (e: ExpressionEncoder[_], a) if !e.flat => - e.nested(a).resolve(execution.analyzed.output) - case (e, a) => - e.unbind(a :: Nil).resolve(execution.analyzed.output) - } + val encoders = columns.map(_.encoder) + println(tEnc) + val namedColumns = nameColumns(columns.map(_.withInputType(tEnc, dataAttributes))) + val execution = new QueryExecution(sqlContext, Aggregate(groupingAttributes, groupingAttributes ++ namedColumns, logicalPlan)) new Dataset( sqlContext, execution, - ExpressionEncoder.tuple(encoders)) + ExpressionEncoder.tuple(kEnc +: encoders)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala index 9ca383896a09..b6d3f6287e9d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.analysis.{MultiAlias, UnresolvedAttribute, UnresolvedAlias} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.StructType import scala.util.control.NonFatal @@ -34,4 +37,22 @@ private[sql] trait Queryable { s"Invalid tree; ${e.getMessage}:\n$queryExecution" } } + + protected def nameColumns(columns: Seq[Column]): Seq[NamedExpression] = { + columns.map { + // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we + // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to + // make it a NamedExpression. + case Column(u: UnresolvedAttribute) => UnresolvedAlias(u) + + case Column(expr: NamedExpression) => expr + + // Leave an unaliased generator with an empty list of names since the analyzer will generate + // the correct defaults after the nested expression's type has been resolved. + case Column(explode: Explode) => MultiAlias(explode, Nil) + case Column(jt: JsonTuple) => MultiAlias(jt, Nil) + + case Column(expr: Expression) => Alias(expr, expr.prettyString)() + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index dfcbac8687b3..e33d1e38a20d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -78,8 +78,7 @@ case class TypedAggregateExpression( override lazy val resolved: Boolean = aEncoder.isDefined - override lazy val inputTypes: Seq[DataType] = - aEncoder.map(_.schema.map(_.dataType)).getOrElse(Nil) + override lazy val inputTypes: Seq[DataType] = Nil override val aggBufferSchema: StructType = bEncoder.schema diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index eb6fa1e72e27..46169ca07d71 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -157,7 +157,6 @@ public Integer call(Integer v1, Integer v2) throws Exception { Assert.assertEquals(6, reduced); } - @Test public void testGroupBy() { List data = Arrays.asList("a", "foo", "bar"); Dataset ds = context.createDataset(data, Encoders.STRING()); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 378cd365276b..e12a1d39eab1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -82,45 +82,45 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { ("a", 30), ("b", 3), ("c", 1)) } - test("typed aggregation: TypedAggregator, expr, expr") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - - checkAnswer( - ds.groupBy(_._1).agg( - sum(_._2), - expr("sum(_2)").as[Int], - count("*")), - ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L)) - } - - test("typed aggregation: complex case") { - val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS() - - checkAnswer( - ds.groupBy(_._1).agg( - expr("avg(_2)").as[Double], - TypedAverage.toColumn), - ("a", 2.0, 2.0), ("b", 3.0, 3.0)) - } - - test("typed aggregation: complex result type") { - val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS() - - checkAnswer( - ds.groupBy(_._1).agg( - expr("avg(_2)").as[Double], - ComplexResultAgg.toColumn), - ("a", 2.0, (2L, 4L)), ("b", 3.0, (1L, 3L))) - } - - test("typed aggregation: in project list") { - val ds = Seq(1, 3, 2, 5).toDS() - - checkAnswer( - ds.select(sum((i: Int) => i)), - 11) - checkAnswer( - ds.select(sum((i: Int) => i), sum((i: Int) => i * 2)), - 11 -> 22) - } +// test("typed aggregation: TypedAggregator, expr, expr") { +// val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() +// +// checkAnswer( +// ds.groupBy(_._1).agg( +// sum(_._2), +// expr("sum(_2)").as[Int], +// count("*")), +// ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L)) +// } +// +// test("typed aggregation: complex case") { +// val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS() +// +// checkAnswer( +// ds.groupBy(_._1).agg( +// expr("avg(_2)").as[Double], +// TypedAverage.toColumn), +// ("a", 2.0, 2.0), ("b", 3.0, 3.0)) +// } +// +// test("typed aggregation: complex result type") { +// val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS() +// +// checkAnswer( +// ds.groupBy(_._1).agg( +// expr("avg(_2)").as[Double], +// ComplexResultAgg.toColumn), +// ("a", 2.0, (2L, 4L)), ("b", 3.0, (1L, 3L))) +// } +// +// test("typed aggregation: in project list") { +// val ds = Seq(1, 3, 2, 5).toDS() +// +// checkAnswer( +// ds.select(sum((i: Int) => i)), +// 11) +// checkAnswer( +// ds.select(sum((i: Int) => i), sum((i: Int) => i * 2)), +// 11 -> 22) +// } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 449bab1b23d3..c23dd46d3767 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -219,7 +219,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("groupBy function, reduce") { - val ds = Seq("abc", "xzy", "hello").toDS() + val ds = Seq("abc", "xyz", "hello").toDS() val agged = ds.groupBy(_.length).reduce(_ + _) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 7a8b7ae5bf26..029a7d7b2b18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -89,10 +89,13 @@ abstract class QueryTest extends PlanTest { } if (decoded != expectedAnswer.toSet) { + val expected = expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted + val actual = decoded.toSet.toSeq.map((a: Any) => a.toString).sorted + + val comparision = sideBySide("expected" +: expected, "spark" +: actual).mkString("\n") fail( s"""Decoded objects do not match expected objects: - |Expected: ${expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted} - |Actual ${decoded.toSet.toSeq.map((a: Any) => a.toString).sorted} + |$comparision |${ds.encoder.constructExpression.treeString} """.stripMargin) } From ab2dbb9fcd1146e5fdb22a595a2d1722805291a4 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 12 Nov 2015 00:51:36 -0800 Subject: [PATCH 3/4] some cleanup --- .../scala/org/apache/spark/sql/Encoder.scala | 10 +- .../catalyst/analysis/HiveTypeCoercion.scala | 1 - .../catalyst/encoders/ExpressionEncoder.scala | 132 +++++++----------- .../spark/sql/catalyst/encoders/package.scala | 11 +- .../expressions/complexTypeExtractors.scala | 1 + .../plans/logical/basicOperators.scala | 15 +- .../encoders/ExpressionEncoderSuite.scala | 2 +- .../scala/org/apache/spark/sql/Column.scala | 19 +++ .../org/apache/spark/sql/DataFrame.scala | 2 +- .../scala/org/apache/spark/sql/Dataset.scala | 53 ++++--- .../org/apache/spark/sql/GroupedDataset.scala | 59 ++++---- .../spark/sql/execution/Queryable.scala | 21 --- .../aggregate/TypedAggregateExpression.scala | 10 +- .../spark/sql/execution/basicOperators.scala | 7 +- .../spark/sql/DatasetAggregatorSuite.scala | 124 ++++++++++------ .../org/apache/spark/sql/QueryTest.scala | 6 +- 16 files changed, 257 insertions(+), 216 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 1ff7340557e6..4ba8433694a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -84,7 +84,7 @@ object Encoders { private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = { assert(encoders.length > 1) // make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`. - assert(encoders.forall(_.constructExpression.find(_.isInstanceOf[Attribute]).isEmpty)) + assert(encoders.forall(_.fromRowExpression.find(_.isInstanceOf[Attribute]).isEmpty)) val schema = StructType(encoders.zipWithIndex.map { case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema) @@ -93,8 +93,8 @@ object Encoders { val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") val extractExpressions = encoders.map { - case e if e.flat => e.extractExpressions.head - case other => CreateStruct(other.extractExpressions) + case e if e.flat => e.toRowExpressions.head + case other => CreateStruct(other.toRowExpressions) }.zipWithIndex.map { case (expr, index) => expr.transformUp { case BoundReference(0, t: ObjectType, _) => @@ -107,11 +107,11 @@ object Encoders { val constructExpressions = encoders.zipWithIndex.map { case (enc, index) => if (enc.flat) { - enc.constructExpression.transform { + enc.fromRowExpression.transform { case b: BoundReference => b.copy(ordinal = index) } } else { - enc.constructExpression.transformUp { + enc.fromRowExpression.transformUp { case BoundReference(ordinal, dt, _) => GetInternalRowField(BoundReference(index, enc.schema, nullable = true), ordinal, dt) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 2d8060086003..bf2bff0243fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -712,7 +712,6 @@ object HiveTypeCoercion { case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => - println(s"$in -> $expected") // If we cannot do the implicit cast, just use the original input. implicitCast(in, expected).getOrElse(in) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 1e9adea33202..0d3e4aafb0af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -61,31 +61,23 @@ object ExpressionEncoder { /** * Given a set of N encoders, constructs a new encoder that produce objects as items in an - * N-tuple. Note that these encoders should first be bound correctly to the combined input - * schema. + * N-tuple. Note that these encoders should be unresolved so that information about + * name/positional binding is preserved. */ def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = { + encoders.foreach(_.assertUnresolved()) + val schema = StructType( - encoders.zipWithIndex.map { case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)}) + encoders.zipWithIndex.map { + case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema) + }) val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") - // Rebind the encoders to the nested schema that will be produced by the aggregation. + // Rebind the encoders to the nested schema. val newConstructExpressions = encoders.zipWithIndex.map { - case (e, i) if !e.flat => - println(s"=== $i - nested ===") - println(e.constructExpression.treeString) - println() - println(e.nested(i).constructExpression.treeString) - - e.nested(i).constructExpression - case (e, i) => - println(s"=== $i - flat ===") - println(e.constructExpression.treeString) - println() - println(e.shift(i).constructExpression.treeString) - - e.shift(i).constructExpression + case (e, i) if !e.flat => e.nested(i).fromRowExpression + case (e, i) => e.shift(i).fromRowExpression } val constructExpression = @@ -93,11 +85,11 @@ object ExpressionEncoder { val input = BoundReference(0, ObjectType(cls), false) val extractExpressions = encoders.zipWithIndex.map { - case (e, i) if !e.flat => CreateStruct(e.extractExpressions.map(_ transformUp { + case (e, i) if !e.flat => CreateStruct(e.toRowExpressions.map(_ transformUp { case b: BoundReference => Invoke(input, s"_${i + 1}", b.dataType, Nil) })) - case (e, i) => e.extractExpressions.head transformUp { + case (e, i) => e.toRowExpressions.head transformUp { case b: BoundReference => Invoke(input, s"_${i + 1}", b.dataType, Nil) } @@ -122,26 +114,27 @@ object ExpressionEncoder { * A generic encoder for JVM objects. * * @param schema The schema after converting `T` to a Spark SQL row. - * @param extractExpressions A set of expressions, one for each top-level field that can be used to - * extract the values from a raw object. + * @param toRowExpressions A set of expressions, one for each top-level field that can be used to + * extract the values from a raw object into an [[InternalRow]]. + * @param fromRowExpression An expression that will construct an object given an [[InternalRow]]. * @param clsTag A classtag for `T`. */ case class ExpressionEncoder[T]( schema: StructType, flat: Boolean, - extractExpressions: Seq[Expression], - constructExpression: Expression, + toRowExpressions: Seq[Expression], + fromRowExpression: Expression, clsTag: ClassTag[T]) extends Encoder[T] { - if (flat) require(extractExpressions.size == 1) + if (flat) require(toRowExpressions.size == 1) @transient - private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions) + private lazy val extractProjection = GenerateUnsafeProjection.generate(toRowExpressions) private val inputRow = new GenericMutableRow(1) @transient - private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil) + private lazy val constructProjection = GenerateSafeProjection.generate(fromRowExpression :: Nil) /** * Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to @@ -153,7 +146,8 @@ case class ExpressionEncoder[T]( extractProjection(inputRow) } catch { case e: Exception => - throw new RuntimeException(s"Error while encoding: $e\n${extractExpressions.map(_.treeString).mkString("\n")}", e) + throw new RuntimeException( + s"Error while encoding: $e\n${toRowExpressions.map(_.treeString).mkString("\n")}", e) } /** @@ -165,7 +159,20 @@ case class ExpressionEncoder[T]( constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T] } catch { case e: Exception => - throw new RuntimeException(s"Error while decoding: $e\n${constructExpression.treeString}", e) + throw new RuntimeException(s"Error while decoding: $e\n${fromRowExpression.treeString}", e) + } + + /** + * The process of resolution to a given schema throws away information about where a given field + * is being bound by ordinal instead of by name. This method checks to make sure this process + * has not been done already in places where we plan to do later composition of encoders. + */ + def assertUnresolved(): Unit = { + (fromRowExpression +: toRowExpressions).foreach(_.foreach { + case a: AttributeReference => + sys.error(s"Unresolved encoder expected, but $a was found.") + case _ => + }) } /** @@ -173,9 +180,14 @@ case class ExpressionEncoder[T]( * given schema. */ def resolve(schema: Seq[Attribute]): ExpressionEncoder[T] = { - val plan = Project(Alias(constructExpression, "")() :: Nil, LocalRelation(schema)) + val positionToAttribute = AttributeMap.toIndex(schema) + val unbound = fromRowExpression transform { + case b: BoundReference => positionToAttribute(b.ordinal) + } + + val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema)) val analyzedPlan = SimpleAnalyzer.execute(plan) - copy(constructExpression = analyzedPlan.expressions.head.children.head) + copy(fromRowExpression = analyzedPlan.expressions.head.children.head) } /** @@ -184,39 +196,14 @@ case class ExpressionEncoder[T]( * resolve before bind. */ def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = { - copy(constructExpression = BindReferences.bindReference(constructExpression, schema)) + copy(fromRowExpression = BindReferences.bindReference(fromRowExpression, schema)) } /** - * Replaces any bound references in the schema with the attributes at the corresponding ordinal - * in the provided schema. This can be used to "relocate" a given encoder to pull values from - * a different schema than it was initially bound to. It can also be used to assign attributes - * to ordinal based extraction (i.e. because the input data was a tuple). + * Returns a new encoder with input columns shifted by `delta` ordinals */ - def unbind(schema: Seq[Attribute]): ExpressionEncoder[T] = { - val positionToAttribute = AttributeMap.toIndex(schema) - copy(constructExpression = constructExpression transform { - case b: BoundReference => positionToAttribute(b.ordinal) - }) - } - - /** - * Given an encoder that has already been bound to a given schema, returns a new encoder - * where the positions are mapped from `oldSchema` to `newSchema`. This can be used, for example, - * when you are trying to use an encoder on grouping keys that were originally part of a larger - * row, but now you have projected out only the key expressions. - */ - def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ExpressionEncoder[T] = { - val positionToAttribute = AttributeMap.toIndex(oldSchema) - val attributeToNewPosition = AttributeMap.byIndex(newSchema) - copy(constructExpression = constructExpression transform { - case r: BoundReference => - r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal))) - }) - } - def shift(delta: Int): ExpressionEncoder[T] = { - copy(constructExpression = constructExpression transform { + copy(fromRowExpression = fromRowExpression transform { case r: BoundReference => r.copy(ordinal = r.ordinal + delta) }) } @@ -226,9 +213,11 @@ case class ExpressionEncoder[T]( * input row have been modified to pull the object out from a nested struct, instead of the * top level fields. */ - def nested(i: Int): ExpressionEncoder[T] = { - val input = BoundReference(i, NullType, true) - copy(constructExpression = constructExpression transformUp { + private def nested(i: Int): ExpressionEncoder[T] = { + // We don't always know our input type at this point since it might be unresolved. + // We fill in null and it will get unbound to the actual attribute at this position. + val input = BoundReference(i, NullType, nullable = true) + copy(fromRowExpression = fromRowExpression transformUp { case u: Attribute => UnresolvedExtractValue(input, Literal(u.name)) case b: BoundReference => @@ -239,24 +228,7 @@ case class ExpressionEncoder[T]( }) } - /** - * Returns a copy of this encoder where the expressions used to create an object given an - * input row have been modified to pull the object out from a nested struct, instead of the - * top level fields. - */ - def nested(input: Expression = BoundReference(0, schema, true)): ExpressionEncoder[T] = { - copy(constructExpression = constructExpression transform { - case u: Attribute if u != input => - UnresolvedExtractValue(input, Literal(u.name)) - case b: BoundReference if b != input => - GetStructField( - input, - StructField(s"i[${b.ordinal}]", b.dataType), - b.ordinal) - }) - } - - protected val attrs = extractExpressions.flatMap(_.collect { + protected val attrs = toRowExpressions.flatMap(_.collect { case _: UnresolvedAttribute => "" case a: Attribute => s"#${a.exprId}" case b: BoundReference => s"[${b.ordinal}]" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala index 2c35adca9c92..9e283f5eb634 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala @@ -18,10 +18,19 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.expressions.AttributeReference package object encoders { + /** + * Returns an internal encoder object that can be used to serialize / deserialize JVM objects + * into Spark SQL rows. The implicit encoder should always be unresolved (i.e. have no attribute + * references from a specific schema.) This requirement allows us to preserve whether a given + * object type is being bound by name or by ordinal when doing resolution. + */ private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match { - case e: ExpressionEncoder[A] => e + case e: ExpressionEncoder[A] => + e.assertUnresolved() + e case _ => sys.error(s"Only expression encoders are supported today") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 96652b65fa0c..dbd3d5c85acb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -97,6 +97,7 @@ object ExtractValue { * Returns the value of fields in the Struct `child`. * * No need to do type checking since it is handled by [[ExtractValue]]. + * TODO: Unify with [[GetInternalRowField]], remove the need to specify a [[StructField]]. */ case class GetStructField(child: Expression, field: StructField, ordinal: Int) extends UnaryExpression { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 597f03e75270..8f43716485cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -469,9 +469,12 @@ case class MapPartitions[T, U]( /** Factory for constructing new `AppendColumn` nodes. */ object AppendColumn { - def apply[T : Encoder, U : Encoder](func: T => U, child: LogicalPlan): AppendColumn[T, U] = { + def apply[T, U : Encoder]( + func: T => U, + tEncoder: ExpressionEncoder[T], + child: LogicalPlan): AppendColumn[T, U] = { val attrs = encoderFor[U].schema.toAttributes - new AppendColumn[T, U](func, encoderFor[T], encoderFor[U], attrs, child) + new AppendColumn[T, U](func, tEncoder, encoderFor[U], attrs, child) } } @@ -492,14 +495,16 @@ case class AppendColumn[T, U]( /** Factory for constructing new `MapGroups` nodes. */ object MapGroups { - def apply[K : Encoder, T : Encoder, U : Encoder]( + def apply[K, T, U : Encoder]( func: (K, Iterator[T]) => TraversableOnce[U], + kEncoder: ExpressionEncoder[K], + tEncoder: ExpressionEncoder[T], groupingAttributes: Seq[Attribute], child: LogicalPlan): MapGroups[K, T, U] = { new MapGroups( func, - encoderFor[K], - encoderFor[T], + kEncoder, + tEncoder, encoderFor[U], groupingAttributes, encoderFor[U].schema.toAttributes, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index b0dacf7f555e..a5c9f9646bd3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -237,7 +237,7 @@ class ExpressionEncoderSuite extends SparkFunSuite { } val convertedData = encoder.toRow(inputData) val schema = encoder.schema.toAttributes - val boundEncoder = encoder.resolve(schema).bind(schema) + val boundEncoder = encoder.resolve(schema) val convertedBack = try boundEncoder.fromRow(convertedData) catch { case e: Exception => fail( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 06d9061be0c1..929224460dc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -93,6 +93,25 @@ class Column(protected[sql] val expr: Expression) extends Logging { /** Creates a column based on the given expression. */ private def withExpr(newExpr: Expression): Column = new Column(newExpr) + /** + * Returns the expression for this column either with an existing or auto assigned name. + */ + private[sql] def named: NamedExpression = expr match { + // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we + // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to + // make it a NamedExpression. + case u: UnresolvedAttribute => UnresolvedAlias(u) + + case expr: NamedExpression => expr + + // Leave an unaliased generator with an empty list of names since the analyzer will generate + // the correct defaults after the nested expression's type has been resolved. + case explode: Explode => MultiAlias(explode, Nil) + case jt: JsonTuple => MultiAlias(jt, Nil) + + case expr: Expression => Alias(expr, expr.prettyString)() + } + override def toString: String = expr.prettyString override def equals(that: Any): Boolean = that match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 553c10757cce..3ba4ba18d212 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -735,7 +735,7 @@ class DataFrame private[sql]( */ @scala.annotation.varargs def select(cols: Column*): DataFrame = withPlan { - Project(nameColumns(cols), logicalPlan) + Project(cols.map(_.named), logicalPlan) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 95bb4e4ad344..b930e4661c1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{Queryable, QueryExecution} -import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.types.StructType /** @@ -63,16 +62,20 @@ import org.apache.spark.sql.types.StructType class Dataset[T] private[sql]( @transient val sqlContext: SQLContext, @transient val queryExecution: QueryExecution, - implicit val unresolvedEncoder: Encoder[T]) extends Queryable with Serializable { + tEncoder: Encoder[T]) extends Queryable with Serializable { + + /** + * An unresolved version of the internal encoder for the type of this dataset. This one is marked + * implicit so that we can use it when constructing new [[Dataset]] objects that have the same + * object type (that will be possibly resolved to a different schema). + */ + private implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder) /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ - private[sql] val encoder: ExpressionEncoder[T] = unresolvedEncoder match { - case e: ExpressionEncoder[T] => - e.unbind(queryExecution.analyzed.output).resolve(queryExecution.analyzed.output) - case _ => throw new IllegalArgumentException("Only expression encoders are currently supported") - } + private[sql] val resolvedTEncoder: ExpressionEncoder[T] = + unresolvedTEncoder.resolve(queryExecution.analyzed.output) - private implicit def classTag = encoder.clsTag + private implicit def classTag = resolvedTEncoder.clsTag private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) = this(sqlContext, new QueryExecution(sqlContext, plan), encoder) @@ -82,7 +85,7 @@ class Dataset[T] private[sql]( * * @since 1.6.0 */ - def schema: StructType = encoder.schema + def schema: StructType = resolvedTEncoder.schema /* ************* * * Conversions * @@ -135,7 +138,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def rdd: RDD[T] = { - val tEnc = encoder + val tEnc = resolvedTEncoder val input = queryExecution.analyzed.output queryExecution.toRdd.mapPartitions { iter => val bound = tEnc.bind(input) @@ -296,12 +299,12 @@ class Dataset[T] private[sql]( */ def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = { val inputPlan = queryExecution.analyzed - val withGroupingKey = AppendColumn(func, inputPlan) + val withGroupingKey = AppendColumn(func, resolvedTEncoder, inputPlan) val executed = sqlContext.executePlan(withGroupingKey) new GroupedDataset( - encoderFor[K].resolve(withGroupingKey.newColumns), - encoderFor[T].bind(inputPlan.output), + encoderFor[K], + encoderFor[T], executed, inputPlan.output, withGroupingKey.newColumns) @@ -361,7 +364,15 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = { - new Dataset[U1](sqlContext, Project(nameColumns(c1.withInputType(encoder, queryExecution.analyzed.output) :: Nil), logicalPlan)) + // We use an unbound encoder since the expression will make up its own schema. + // TODO: This probably doesn't work if we are relying on reordering of the input class fields. + new Dataset[U1]( + sqlContext, + Project( + c1.withInputType( + resolvedTEncoder.bind(queryExecution.analyzed.output), + queryExecution.analyzed.output).named :: Nil, + logicalPlan)) } /** @@ -371,7 +382,10 @@ class Dataset[T] private[sql]( */ protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) - val namedColumns = nameColumns(columns.map(_.withInputType(encoder, queryExecution.analyzed.output))) + // We use an unbound encoder since the expression will make up its own schema. + // TODO: This probably doesn't work if we are relying on reordering of the input class fields. + val namedColumns = + columns.map(_.withInputType(unresolvedTEncoder, queryExecution.analyzed.output).named) val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan)) new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) @@ -481,17 +495,18 @@ class Dataset[T] private[sql]( val left = this.logicalPlan val right = other.logicalPlan - val leftData = this.encoder match { + val leftData = this.unresolvedTEncoder match { case e if e.flat => Alias(left.output.head, "_1")() case _ => Alias(CreateStruct(left.output), "_1")() } - val rightData = other.encoder match { + val rightData = other.unresolvedTEncoder match { case e if e.flat => Alias(right.output.head, "_2")() case _ => Alias(CreateStruct(right.output), "_2")() } - implicit val tuple2Encoder: Encoder[(T, U)] = ExpressionEncoder.tuple(unresolvedEncoder.asInstanceOf[ExpressionEncoder[T]], other.unresolvedEncoder.asInstanceOf[ExpressionEncoder[U]]) + implicit val tuple2Encoder: Encoder[(T, U)] = + ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder) withPlan[(T, U)](other) { (left, right) => Project( leftData :: rightData :: Nil, @@ -558,7 +573,7 @@ class Dataset[T] private[sql]( private[sql] def logicalPlan = queryExecution.analyzed private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] = - new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), unresolvedEncoder) + new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), tEncoder) private[sql] def withPlan[R : Encoder]( other: Dataset[_])( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index a4ddb705ec57..ae1272ae531f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -17,19 +17,15 @@ package org.apache.spark.sql -import org.apache.spark.sql.types.StructType import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute} import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} -import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression -import org.apache.spark.sql.execution.{Queryable, QueryExecution} - +import org.apache.spark.sql.execution.QueryExecution /** * :: Experimental :: @@ -44,25 +40,21 @@ import org.apache.spark.sql.execution.{Queryable, QueryExecution} */ @Experimental class GroupedDataset[K, T] private[sql]( - private val kEncoder: Encoder[K], - private val tEncoder: Encoder[T], + kEncoder: Encoder[K], + tEncoder: Encoder[T], val queryExecution: QueryExecution, private val dataAttributes: Seq[Attribute], - private val groupingAttributes: Seq[Attribute]) extends Queryable with Serializable { + private val groupingAttributes: Seq[Attribute]) extends Serializable { - def schema: StructType = ??? + // Similar to [[Dataset]], we use unresolved encoders for later composition and resolved encoders + // when constructing new logical plans that will operate on the output of the current + // queryexecution. - private implicit val kEnc = kEncoder match { - case e: ExpressionEncoder[K] => e.unbind(groupingAttributes).resolve(groupingAttributes) - case other => - throw new UnsupportedOperationException("Only expression encoders are currently supported") - } + private implicit val unresolvedKEncoder = encoderFor(kEncoder) + private implicit val unresolvedTEncoder = encoderFor(tEncoder) - private implicit val tEnc = tEncoder match { - case e: ExpressionEncoder[T] => e.resolve(dataAttributes) - case other => - throw new UnsupportedOperationException("Only expression encoders are currently supported") - } + private val resolvedKEncoder = unresolvedKEncoder.resolve(groupingAttributes) + private val resolvedTEncoder = unresolvedTEncoder.resolve(dataAttributes) /** Encoders for built in aggregations. */ private implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true) @@ -81,7 +73,7 @@ class GroupedDataset[K, T] private[sql]( def asKey[L : Encoder]: GroupedDataset[L, T] = new GroupedDataset( encoderFor[L], - tEncoder, + unresolvedTEncoder, queryExecution, dataAttributes, groupingAttributes) @@ -97,7 +89,7 @@ class GroupedDataset[K, T] private[sql]( } /** - * Applies the given function to each group of data. For each unique group, the function will + * Applies the given function to each group of data. For each unique group, the function will * be passed the group key and an iterator that contains all of the elements in the group. The * function can return an iterator containing elements of an arbitrary type which will be returned * as a new [[Dataset]]. @@ -110,7 +102,12 @@ class GroupedDataset[K, T] private[sql]( def flatMap[U : Encoder](f: (K, Iterator[T]) => TraversableOnce[U]): Dataset[U] = { new Dataset[U]( sqlContext, - MapGroups(f, groupingAttributes, logicalPlan)) + MapGroups( + f, + resolvedKEncoder, + resolvedTEncoder, + groupingAttributes, + logicalPlan)) } def flatMap[U](f: FlatMapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = { @@ -142,9 +139,9 @@ class GroupedDataset[K, T] private[sql]( */ def reduce(f: (T, T) => T): Dataset[(K, T)] = { val func = (key: K, it: Iterator[T]) => Iterator(key -> it.reduce(f)) - println(ExpressionEncoder.tuple(kEnc, tEnc).constructExpression.treeString) - flatMap(func)(ExpressionEncoder.tuple(kEncoder.asInstanceOf[ExpressionEncoder[K]], tEncoder.asInstanceOf[ExpressionEncoder[T]])) + implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedTEncoder) + flatMap(func) } def reduce(f: ReduceFunction[T]): Dataset[(K, T)] = { @@ -163,14 +160,16 @@ class GroupedDataset[K, T] private[sql]( */ protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) - println(tEnc) - val namedColumns = nameColumns(columns.map(_.withInputType(tEnc, dataAttributes))) - val execution = new QueryExecution(sqlContext, Aggregate(groupingAttributes, groupingAttributes ++ namedColumns, logicalPlan)) + val namedColumns = + columns.map( + _.withInputType(resolvedTEncoder.bind(dataAttributes), dataAttributes).named) + val aggregate = Aggregate(groupingAttributes, groupingAttributes ++ namedColumns, logicalPlan) + val execution = new QueryExecution(sqlContext, aggregate) new Dataset( sqlContext, execution, - ExpressionEncoder.tuple(kEnc +: encoders)) + ExpressionEncoder.tuple(unresolvedKEncoder +: encoders)) } /** @@ -223,7 +222,7 @@ class GroupedDataset[K, T] private[sql]( def cogroup[U, R : Encoder]( other: GroupedDataset[K, U])( f: (K, Iterator[T], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { - implicit def uEnc: Encoder[U] = other.tEncoder + implicit def uEnc: Encoder[U] = other.unresolvedTEncoder new Dataset[R]( sqlContext, CoGroup( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala index b6d3f6287e9d..9ca383896a09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala @@ -17,9 +17,6 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.analysis.{MultiAlias, UnresolvedAttribute, UnresolvedAlias} -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.StructType import scala.util.control.NonFatal @@ -37,22 +34,4 @@ private[sql] trait Queryable { s"Invalid tree; ${e.getMessage}:\n$queryExecution" } } - - protected def nameColumns(columns: Seq[Column]): Seq[NamedExpression] = { - columns.map { - // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we - // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to - // make it a NamedExpression. - case Column(u: UnresolvedAttribute) => UnresolvedAlias(u) - - case Column(expr: NamedExpression) => expr - - // Leave an unaliased generator with an empty list of names since the analyzer will generate - // the correct defaults after the nested expression's type has been resolved. - case Column(explode: Explode) => MultiAlias(explode, Nil) - case Column(jt: JsonTuple) => MultiAlias(jt, Nil) - - case Column(expr: Expression) => Alias(expr, expr.prettyString)() - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index e33d1e38a20d..3f2775896bb8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -55,7 +55,7 @@ case class TypedAggregateExpression( aEncoder: Option[ExpressionEncoder[Any]], bEncoder: ExpressionEncoder[Any], cEncoder: ExpressionEncoder[Any], - children: Seq[Expression], + children: Seq[Attribute], mutableAggBufferOffset: Int, inputAggBufferOffset: Int) extends ImperativeAggregate with Logging { @@ -89,12 +89,8 @@ case class TypedAggregateExpression( override val inputAggBufferAttributes: Seq[AttributeReference] = aggBufferAttributes.map(_.newInstance()) - lazy val inputAttributes = aEncoder.get.schema.toAttributes - lazy val inputMapping = AttributeMap(inputAttributes.zip(children)) - lazy val boundA = - aEncoder.get.copy(constructExpression = aEncoder.get.constructExpression transform { - case a: AttributeReference => inputMapping(a) - }) + // We let the dataset do the binding for us. + lazy val boundA = aEncoder.get val bAttributes = bEncoder.schema.toAttributes lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 303d636164ad..2cef3231fbd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -337,6 +337,10 @@ case class AppendColumns[T, U]( newColumns: Seq[Attribute], child: SparkPlan) extends UnaryNode { + // We are using an unsafe combiner. + override def canProcessSafeRows: Boolean = false + override def canProcessUnsafeRows: Boolean = true + override def output: Seq[Attribute] = child.output ++ newColumns override protected def doExecute(): RDD[InternalRow] = { @@ -375,11 +379,12 @@ case class MapGroups[K, T, U]( child.execute().mapPartitions { iter => val grouped = GroupedIterator(iter, groupingAttributes, child.output) val groupKeyEncoder = kEncoder.bind(groupingAttributes) + val groupDataEncoder = tEncoder.bind(child.output) grouped.flatMap { case (key, rowIter) => val result = func( groupKeyEncoder.fromRow(key), - rowIter.map(tEncoder.fromRow)) + rowIter.map(groupDataEncoder.fromRow)) result.map(uEncoder.toRow) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index e12a1d39eab1..20896efdfec1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -67,6 +67,28 @@ object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, L override def finish(reduction: (Long, Long)): (Long, Long) = reduction } +case class AggData(a: Int, b: String) +object ClassInputAgg extends Aggregator[AggData, Int, Int] with Serializable { + /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */ + override def zero: Int = 0 + + /** + * Combine two values to produce a new value. For performance, the function may modify `b` and + * return it instead of constructing new object for b. + */ + override def reduce(b: Int, a: AggData): Int = b + a.a + + /** + * Transform the output of the reduction. + */ + override def finish(reduction: Int): Int = reduction + + /** + * Merge two intermediate values + */ + override def merge(b1: Int, b2: Int): Int = b1 + b2 +} + class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -82,45 +104,65 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { ("a", 30), ("b", 3), ("c", 1)) } -// test("typed aggregation: TypedAggregator, expr, expr") { -// val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() -// -// checkAnswer( -// ds.groupBy(_._1).agg( -// sum(_._2), -// expr("sum(_2)").as[Int], -// count("*")), -// ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L)) -// } -// -// test("typed aggregation: complex case") { -// val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS() -// -// checkAnswer( -// ds.groupBy(_._1).agg( -// expr("avg(_2)").as[Double], -// TypedAverage.toColumn), -// ("a", 2.0, 2.0), ("b", 3.0, 3.0)) -// } -// -// test("typed aggregation: complex result type") { -// val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS() -// -// checkAnswer( -// ds.groupBy(_._1).agg( -// expr("avg(_2)").as[Double], -// ComplexResultAgg.toColumn), -// ("a", 2.0, (2L, 4L)), ("b", 3.0, (1L, 3L))) -// } -// -// test("typed aggregation: in project list") { -// val ds = Seq(1, 3, 2, 5).toDS() -// -// checkAnswer( -// ds.select(sum((i: Int) => i)), -// 11) -// checkAnswer( -// ds.select(sum((i: Int) => i), sum((i: Int) => i * 2)), -// 11 -> 22) -// } + test("typed aggregation: TypedAggregator, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkAnswer( + ds.groupBy(_._1).agg( + sum(_._2), + expr("sum(_2)").as[Int], + count("*")), + ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L)) + } + + test("typed aggregation: complex case") { + val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS() + + checkAnswer( + ds.groupBy(_._1).agg( + expr("avg(_2)").as[Double], + TypedAverage.toColumn), + ("a", 2.0, 2.0), ("b", 3.0, 3.0)) + } + + test("typed aggregation: complex result type") { + val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS() + + checkAnswer( + ds.groupBy(_._1).agg( + expr("avg(_2)").as[Double], + ComplexResultAgg.toColumn), + ("a", 2.0, (2L, 4L)), ("b", 3.0, (1L, 3L))) + } + + test("typed aggregation: in project list") { + val ds = Seq(1, 3, 2, 5).toDS() + + checkAnswer( + ds.select(sum((i: Int) => i)), + 11) + checkAnswer( + ds.select(sum((i: Int) => i), sum((i: Int) => i * 2)), + 11 -> 22) + } + + test("typed aggregation: class input") { + val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS() + + checkAnswer( + ds.select(ClassInputAgg.toColumn), + 3) + } + + test("typed aggregation: class input with reordering") { + val ds = sql("SELECT 'one' AS b, 1 as a").as[AggData] + + checkAnswer( + ds.select(ClassInputAgg.toColumn), + 1) + + checkAnswer( + ds.groupBy(_.b).agg(ClassInputAgg.toColumn), + ("one", 1)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 029a7d7b2b18..b5417b195f39 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -82,8 +82,8 @@ abstract class QueryTest extends PlanTest { fail( s""" |Exception collecting dataset as objects - |${ds.encoder} - |${ds.encoder.constructExpression.treeString} + |${ds.resolvedTEncoder} + |${ds.resolvedTEncoder.fromRowExpression.treeString} |${ds.queryExecution} """.stripMargin, e) } @@ -96,7 +96,7 @@ abstract class QueryTest extends PlanTest { fail( s"""Decoded objects do not match expected objects: |$comparision - |${ds.encoder.constructExpression.treeString} + |${ds.resolvedTEncoder.fromRowExpression.treeString} """.stripMargin) } } From f2474970c1b4b4de44fd30f80ef7fdd73742facc Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 12 Nov 2015 13:20:03 -0800 Subject: [PATCH 4/4] fix tests --- .../sql/catalyst/expressions/complexTypeExtractors.scala | 6 +++++- .../sql/catalyst/encoders/ExpressionEncoderSuite.scala | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index dbd3d5c85acb..f871b737fff3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -102,7 +102,11 @@ object ExtractValue { case class GetStructField(child: Expression, field: StructField, ordinal: Int) extends UnaryExpression { - override def dataType: DataType = child.dataType.asInstanceOf[StructType](ordinal).dataType + override def dataType: DataType = child.dataType match { + case s: StructType => s(ordinal).dataType + // This is a hack to avoid breaking existing code until we remove the need for the struct field + case _ => field.dataType + } override def nullable: Boolean = child.nullable || field.nullable override def toString: String = s"$child.${field.name}" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index a5c9f9646bd3..b0dacf7f555e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -237,7 +237,7 @@ class ExpressionEncoderSuite extends SparkFunSuite { } val convertedData = encoder.toRow(inputData) val schema = encoder.schema.toAttributes - val boundEncoder = encoder.resolve(schema) + val boundEncoder = encoder.resolve(schema).bind(schema) val convertedBack = try boundEncoder.fromRow(convertedData) catch { case e: Exception => fail(