Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,30 @@ class ReplSuite extends SparkFunSuite {
}
}

test("Datasets agg type-inference") {
val output = runInterpreter("local",
"""
|import org.apache.spark.sql.functions._
|import org.apache.spark.sql.Encoder
|import org.apache.spark.sql.expressions.Aggregator
|import org.apache.spark.sql.TypedColumn
|/** An `Aggregator` that adds up any numeric type returned by the given function. */
|class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable {
| val numeric = implicitly[Numeric[N]]
| override def zero: N = numeric.zero
| override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
| override def merge(b1: N,b2: N): N = numeric.plus(b1, b2)
| override def finish(reduction: N): N = reduction
|}
|
|def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
|val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS()
|ds.groupBy(_._1).agg(sum(_._2), sum(_._3)).collect()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dragos any idea why it fails to infer the type only in the REPL?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't reproduce the difference. It won't infer it in a standalone program either.

As I mentioned in our conversation, it's a chicken and egg problem: type inference is guided by the expected type, but if the method is overloaded, the expected type is not known. And the argument type is what guides overload resolution. It works in simple cases, when the overloads have different arities, but with varargs that's no longer the case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It always works fine here though. Why is this different?

""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
}

test("collecting objects of class defined in repl") {
val output = runInterpreter("local[2]",
"""
Expand Down
27 changes: 3 additions & 24 deletions sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -146,31 +146,10 @@ class GroupedDataset[K, T] private[sql](
reduce(f.call _)
}

/**
* Compute aggregates by specifying a series of aggregate columns, and return a [[DataFrame]].
* We can call `as[T : Encoder]` to turn the returned [[DataFrame]] to [[Dataset]] again.
*
* The available aggregate methods are defined in [[org.apache.spark.sql.functions]].
*
* {{{
* // Selects the age of the oldest employee and the aggregate expense for each department
*
* // Scala:
* import org.apache.spark.sql.functions._
* df.groupBy("department").agg(max("age"), sum("expense"))
*
* // Java:
* import static org.apache.spark.sql.functions.*;
* df.groupBy("department").agg(max("age"), sum("expense"));
* }}}
*
* We can also use `Aggregator.toColumn` to pass in typed aggregate functions.
*
* @since 1.6.0
*/
// This is here to prevent us from adding overloads that would be ambiguous.
@scala.annotation.varargs
def agg(expr: Column, exprs: Column*): DataFrame =
groupedData.agg(withEncoder(expr), exprs.map(withEncoder): _*)
// Passes every column through `withEncoder`, then delegates to `groupedData.agg` with the
// first column as the required argument and the remaining columns as varargs.
// NOTE(review): `exprs.head` is partial — an empty argument list would throw; confirm that
// no caller can reach this with zero columns.
private def agg(exprs: Column*): DataFrame =
groupedData.agg(withEncoder(exprs.head), exprs.tail.map(withEncoder): _*)

private def withEncoder(c: Column): Column = c match {
case tc: TypedColumn[_, _] =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,11 +404,9 @@ public String call(Tuple2<String, Integer> value) throws Exception {
grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()));
Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());

Dataset<Tuple4<String, Integer, Long, Long>> agged2 = grouped.agg(
new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()),
expr("sum(_2)"),
count("*"))
.as(Encoders.tuple(Encoders.STRING(), Encoders.INT(), Encoders.LONG(), Encoders.LONG()));
Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(
new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()))
.as(Encoders.tuple(Encoders.STRING(), Encoders.INT()));
Assert.assertEquals(
Arrays.asList(
new Tuple4<>("a", 3, 3L, 2L),
Expand Down