diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
index 16e58ad8f6a11..181033c9d20c8 100644
--- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
+++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
@@ -157,6 +157,10 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
    */
   def value: OUT
 
+  // Serialize the buffer of this accumulator before sending back this accumulator to the driver.
+  // By default this method does nothing.
+  protected def withBufferSerialized(): AccumulatorV2[IN, OUT] = this
+
   // Called by Java when serializing an object
   final protected def writeReplace(): Any = {
     if (atDriverSide) {
@@ -179,7 +183,7 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
       }
       copyAcc
     } else {
-      this
+      withBufferSerialized()
     }
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
index 8cce79c43c533..3eeee557287d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
@@ -276,7 +276,7 @@ object ApproximatePercentile {
   }
 
   /**
-   * Serializer for class [[PercentileDigest]] 
+   * Serializer for class [[PercentileDigest]]
    *
    * This class is thread safe.
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala
index f884e8efad363..d528e9114baa6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala
@@ -163,9 +163,18 @@ class AggregatingAccumulator private(
           i += 1
         }
         i = 0
-        while (i < typedImperatives.length) {
-          typedImperatives(i).mergeBuffersObjects(buffer, otherBuffer)
-          i += 1
+        if (isAtDriverSide) {
+          while (i < typedImperatives.length) {
+            // The input buffer stores serialized data
+            typedImperatives(i).merge(buffer, otherBuffer)
+            i += 1
+          }
+        } else {
+          while (i < typedImperatives.length) {
+            // The input buffer stores deserialized object
+            typedImperatives(i).mergeBuffersObjects(buffer, otherBuffer)
+            i += 1
+          }
         }
       case _ =>
         throw QueryExecutionErrors.cannotMergeClassWithOtherClassError(
@@ -188,6 +197,17 @@ class AggregatingAccumulator private(
     resultProjection(input)
   }
 
+  override def withBufferSerialized(): AggregatingAccumulator = {
+    assert(!isAtDriverSide)
+    var i = 0
+    // AggregatingAccumulator runs on executor, we should serialize all TypedImperativeAggregate.
+    while (i < typedImperatives.length) {
+      typedImperatives(i).serializeAggregateBufferInPlace(buffer)
+      i += 1
+    }
+    this
+  }
+
   /**
    * Get the output schema of the aggregating accumulator.
    */
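The core of the fix is the `writeReplace()` hook above: on the executor side, an accumulator now gets a chance to flatten a non-serializable aggregation buffer into bytes before Java serialization ships it to the driver. Below is a minimal, self-contained sketch of that pattern outside Spark; `Digest`, `SketchAccumulator`, and their members are hypothetical names for illustration, not Spark API.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Hypothetical stand-in for an aggregation buffer that is not java.io.Serializable.
class Digest(var state: String)

class SketchAccumulator extends Serializable {
  // Live executor-side buffer; @transient keeps Java serialization away from it.
  @transient private var digest: Digest = new Digest("state")
  // Byte form of the buffer that actually travels to the driver.
  private var serializedDigest: Array[Byte] = _

  def add(s: String): Unit = digest.state += s

  // Counterpart of the patch's withBufferSerialized(): flatten the live buffer
  // into bytes before the object is shipped.
  private def withBufferSerialized(): SketchAccumulator = {
    serializedDigest = digest.state.getBytes("UTF-8")
    this
  }

  // Java serialization calls writeReplace() before writing the object,
  // which is the hook AccumulatorV2 uses in the patch.
  private def writeReplace(): Any = withBufferSerialized()

  def value: String = new String(serializedDigest, "UTF-8")
}

object SketchAccumulatorDemo {
  def main(args: Array[String]): Unit = {
    val acc = new SketchAccumulator
    acc.add("-more")

    // Round-trip through Java serialization, as happens when an accumulator
    // is sent from an executor back to the driver.
    val out = new ByteArrayOutputStream()
    new ObjectOutputStream(out).writeObject(acc)
    val in = new ObjectInputStream(new ByteArrayInputStream(out.toByteArray))
    val shipped = in.readObject().asInstanceOf[SketchAccumulator]

    assert(shipped.value == "state-more") // buffer survived as bytes
  }
}
```

`AggregatingAccumulator.withBufferSerialized()` applies the same idea concretely: it calls `serializeAggregateBufferInPlace` on each `TypedImperativeAggregate`, which is why the driver-side branch of `merge` must use `merge(buffer, otherBuffer)` on serialized buffers rather than `mergeBuffersObjects`.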
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index a46ef5d8e9cdb..2ca3f202a6e94 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -754,6 +754,19 @@ class DatasetSuite extends QueryTest
     assert(err2.getMessage.contains("Name must not be empty"))
   }
 
+  test("SPARK-37203: Fix NotSerializableException when observe with TypedImperativeAggregate") {
+    def observe[T](df: Dataset[T], expected: Map[String, _]): Unit = {
+      val namedObservation = Observation("named")
+      val observed_df = df.observe(
+        namedObservation, percentile_approx($"id", lit(0.5), lit(100)).as("percentile_approx_val"))
+      observed_df.collect()
+      assert(namedObservation.get === expected)
+    }
+
+    observe(spark.range(100), Map("percentile_approx_val" -> 49))
+    observe(spark.range(0), Map("percentile_approx_val" -> null))
+  }
+
   test("sample with replacement") {
     val n = 100
     val data = sparkContext.parallelize(1 to n, 2).toDS()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
index cf1b6bcd0318c..e729fe32ebafa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
@@ -417,7 +417,8 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
         min($"value").as("min_val"),
         max($"value").as("max_val"),
         sum($"value").as("sum_val"),
-        count(when($"value" % 2 === 0, 1)).as("num_even"))
+        count(when($"value" % 2 === 0, 1)).as("num_even"),
+        percentile_approx($"value", lit(0.5), lit(100)).as("percentile_approx_val"))
       .observe(
         name = "other_event",
         avg($"value").cast("int").as("avg_val"))
@@ -444,7 +445,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
       AddData(inputData, 1, 2),
       AdvanceManualClock(100),
       checkMetrics { metrics =>
-        assert(metrics.get("my_event") === Row(1, 2, 3L, 1L))
+        assert(metrics.get("my_event") === Row(1, 2, 3L, 1L, 1))
         assert(metrics.get("other_event") === Row(1))
       },
 
@@ -452,7 +453,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
       AddData(inputData, 10, 30, -10, 5),
       AdvanceManualClock(100),
       checkMetrics { metrics =>
-        assert(metrics.get("my_event") === Row(-10, 30, 35L, 3L))
+        assert(metrics.get("my_event") === Row(-10, 30, 35L, 3L, 5))
         assert(metrics.get("other_event") === Row(8))
       },
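For reference, a hedged usage sketch of the user-facing scenario the new DatasetSuite test exercises: observing a Dataset with `percentile_approx` (backed by a `TypedImperativeAggregate`) and collecting it. Before this patch, the collect failed with `NotSerializableException` when the accumulator was shipped back to the driver. The object name and session settings below are illustrative, not part of the patch.

```scala
import org.apache.spark.sql.{Observation, SparkSession}
import org.apache.spark.sql.functions.{lit, percentile_approx}

object ObservePercentileDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("SPARK-37203-demo")
      .getOrCreate()
    import spark.implicits._

    val observation = Observation("stats")
    val observed = spark.range(100).observe(
      observation,
      percentile_approx($"id", lit(0.5), lit(100)).as("p50"))

    // Collecting forces the observed metrics to be computed; this is the
    // path that previously threw NotSerializableException because the
    // accumulator's PercentileDigest buffer could not be Java-serialized.
    observed.collect()
    println(observation.get) // Map(p50 -> 49)

    spark.stop()
  }
}
```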