From 1eacf82ae9ba0173b842a671cde8c82a4032a0cd Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Wed, 3 Nov 2021 15:58:10 +0800 Subject: [PATCH 1/6] Fix NotSerializableException when observe with percentile_approx --- .../expressions/aggregate/ApproximatePercentile.scala | 4 ++-- .../scala/org/apache/spark/sql/DatasetSuite.scala | 11 +++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index 8cce79c43c53..f9380a154edb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -226,7 +226,7 @@ object ApproximatePercentile { * * @param summaries underlying probabilistic data structure [[QuantileSummaries]]. */ - class PercentileDigest(private var summaries: QuantileSummaries) { + class PercentileDigest(private var summaries: QuantileSummaries) extends Serializable { def this(relativeError: Double) = { this(new QuantileSummaries(defaultCompressThreshold, relativeError, compressed = true)) @@ -276,7 +276,7 @@ object ApproximatePercentile { } /** - * Serializer for class [[PercentileDigest]] + * Serializer for class [[PercentileDigest]] * * This class is thread safe. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index a46ef5d8e9cd..6aadbab99f27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -754,6 +754,17 @@ class DatasetSuite extends QueryTest assert(err2.getMessage.contains("Name must not be empty")) } + test("SPARK-37203: Fix NotSerializableException when observe with percentile_approx") { + val namedObservation = Observation("named") + + val df = spark.range(100) + val observed_df = df.observe( + namedObservation, percentile_approx($"id", lit(0.5), lit(100)).as("percentile_approx_val")) + + observed_df.collect() + assert(namedObservation.get === Map("percentile_approx_val" -> 49)) + } + test("sample with replacement") { val n = 100 val data = sparkContext.parallelize(1 to n, 2).toDS() From a2f1fbbb2f74afefe802624700887ae89dfc7591 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Wed, 3 Nov 2021 17:13:05 +0800 Subject: [PATCH 2/6] Add test --- .../spark/sql/streaming/StreamingQueryListenerSuite.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index cf1b6bcd0318..e729fe32ebaf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -417,7 +417,8 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { min($"value").as("min_val"), max($"value").as("max_val"), sum($"value").as("sum_val"), - count(when($"value" % 2 === 0, 1)).as("num_even")) + count(when($"value" % 2 === 0, 1)).as("num_even"), + percentile_approx($"value", lit(0.5), lit(100)).as("percentile_approx_val")) .observe( name = "other_event", 
avg($"value").cast("int").as("avg_val")) @@ -444,7 +445,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { AddData(inputData, 1, 2), AdvanceManualClock(100), checkMetrics { metrics => - assert(metrics.get("my_event") === Row(1, 2, 3L, 1L)) + assert(metrics.get("my_event") === Row(1, 2, 3L, 1L, 1)) assert(metrics.get("other_event") === Row(1)) }, @@ -452,7 +453,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { AddData(inputData, 10, 30, -10, 5), AdvanceManualClock(100), checkMetrics { metrics => - assert(metrics.get("my_event") === Row(-10, 30, 35L, 3L)) + assert(metrics.get("my_event") === Row(-10, 30, 35L, 3L, 5)) assert(metrics.get("other_event") === Row(8)) }, From 6c2f71c6c1ee27941499fafc0b1fad59b646caa2 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 4 Nov 2021 16:40:36 +0800 Subject: [PATCH 3/6] Improve code --- .../org/apache/spark/util/AccumulatorV2.scala | 5 +++- .../aggregate/ApproximatePercentile.scala | 2 +- .../expressions/aggregate/interfaces.scala | 11 +++++++ .../execution/AggregatingAccumulator.scala | 30 +++++++++++++++++-- 4 files changed, 44 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 16e58ad8f6a1..bad54e2de3bb 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -157,6 +157,9 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { */ def value: OUT + // We assume that serialization of AccumulatorV2 runs on executor is not necessary. + protected def withBufferSerialized(): AccumulatorV2[IN, OUT] = this + // Called by Java when serializing an object final protected def writeReplace(): Any = { if (atDriverSide) { @@ -179,7 +182,7 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { } copyAcc } else { - this + withBufferSerialized() } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index f9380a154edb..3eeee557287d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -226,7 +226,7 @@ object ApproximatePercentile { * * @param summaries underlying probabilistic data structure [[QuantileSummaries]]. 
*/ - class PercentileDigest(private var summaries: QuantileSummaries) extends Serializable { + class PercentileDigest(private var summaries: QuantileSummaries) { def this(relativeError: Double) = { this(new QuantileSummaries(defaultCompressThreshold, relativeError, compressed = true)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 6c22d87923cd..78534d82fc54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -621,6 +621,17 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate { buffer(mutableAggBufferOffset) = serialize(getBufferObject(buffer)) } + /** + * In-place replaces SparkSQL internally supported underlying storage format (BinaryType), + * with the aggregation buffer object stored at buffer's index `mutableAggBufferOffset`. + * + * This is only called when AggregatingAccumulator running on driver, after the framework + * shuffle in aggregate buffers. + */ + final def deserializeAggregateBufferInPlace(buffer: InternalRow): Unit = { + buffer(mutableAggBufferOffset) = deserialize(buffer.getBinary(inputAggBufferOffset)) + } + /** * Merge an input buffer into the aggregation buffer, where both buffers contain the deserialized * java object. This function is used by aggregating accumulators. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala index f884e8efad36..e145b05a57dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala @@ -156,6 +156,15 @@ class AggregatingAccumulator private( case agg: AggregatingAccumulator => val buffer = getOrCreateBuffer() val otherBuffer = agg.buffer + // If AggregatingAccumulator runs on driver, + // we should deserialize all TypedImperativeAggregate. + if (isAtDriverSide) { + var i = 0 + while (i < typedImperatives.length) { + typedImperatives(i).deserializeAggregateBufferInPlace(otherBuffer) + i += 1 + } + } mergeProjection.target(buffer)(joinedRow.withRight(otherBuffer)) var i = 0 while (i < imperatives.length) { @@ -174,9 +183,9 @@ class AggregatingAccumulator private( } } - override def value: InternalRow = withSQLConf(false, InternalRow.empty) { + private def getOrCreateTempBuffer(): SpecificInternalRow = { // Either use the existing buffer or create a temporary one. - val input = if (!isZero) { + if (!isZero) { buffer } else { // Create a temporary buffer because we want to avoid changing the state of the accumulator @@ -185,9 +194,26 @@ class AggregatingAccumulator private( // query execution). createBuffer() } + } + + override def value: InternalRow = withSQLConf(false, InternalRow.empty) { + val input = getOrCreateTempBuffer() resultProjection(input) } + override def withBufferSerialized(): AggregatingAccumulator = { + if (!isAtDriverSide) { + val input = getOrCreateTempBuffer() + var i = 0 + // AggregatingAccumulator runs on executor, we should serialize all TypedImperativeAggregate. 
+ while (i < typedImperatives.length) { + typedImperatives(i).serializeAggregateBufferInPlace(input) + i += 1 + } + } + this + } + /** * Get the output schema of the aggregating accumulator. */ From 84f96570917e646efdf196d37fb9a68eb2e94309 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 4 Nov 2021 17:24:37 +0800 Subject: [PATCH 4/6] Improve code --- .../spark/sql/execution/AggregatingAccumulator.scala | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala index e145b05a57dc..1aabeb9791c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala @@ -183,9 +183,9 @@ class AggregatingAccumulator private( } } - private def getOrCreateTempBuffer(): SpecificInternalRow = { + override def value: InternalRow = withSQLConf(false, InternalRow.empty) { // Either use the existing buffer or create a temporary one. - if (!isZero) { + val input = if (!isZero) { buffer } else { // Create a temporary buffer because we want to avoid changing the state of the accumulator @@ -194,20 +194,15 @@ class AggregatingAccumulator private( // query execution). createBuffer() } - } - - override def value: InternalRow = withSQLConf(false, InternalRow.empty) { - val input = getOrCreateTempBuffer() resultProjection(input) } override def withBufferSerialized(): AggregatingAccumulator = { if (!isAtDriverSide) { - val input = getOrCreateTempBuffer() var i = 0 // AggregatingAccumulator runs on executor, we should serialize all TypedImperativeAggregate. while (i < typedImperatives.length) { - typedImperatives(i).serializeAggregateBufferInPlace(input) + typedImperatives(i).serializeAggregateBufferInPlace(buffer) i += 1 } } From 5e7afacde67bd1c943006250d3aebfbc72cdc5c9 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 4 Nov 2021 23:28:09 +0800 Subject: [PATCH 5/6] Update code --- .../org/apache/spark/util/AccumulatorV2.scala | 3 +- .../expressions/aggregate/interfaces.scala | 11 ------ .../execution/AggregatingAccumulator.scala | 37 +++++++++---------- .../org/apache/spark/sql/DatasetSuite.scala | 2 +- 4 files changed, 21 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index bad54e2de3bb..181033c9d20c 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -157,7 +157,8 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { */ def value: OUT - // We assume that serialization of AccumulatorV2 runs on executor is not necessary. + // Serialize the buffer of this accumulator before sending back this accumulator to the driver. + // By default this method does nothing. 
protected def withBufferSerialized(): AccumulatorV2[IN, OUT] = this // Called by Java when serializing an object diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 78534d82fc54..6c22d87923cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -621,17 +621,6 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate { buffer(mutableAggBufferOffset) = serialize(getBufferObject(buffer)) } - /** - * In-place replaces SparkSQL internally supported underlying storage format (BinaryType), - * with the aggregation buffer object stored at buffer's index `mutableAggBufferOffset`. - * - * This is only called when AggregatingAccumulator running on driver, after the framework - * shuffle in aggregate buffers. - */ - final def deserializeAggregateBufferInPlace(buffer: InternalRow): Unit = { - buffer(mutableAggBufferOffset) = deserialize(buffer.getBinary(inputAggBufferOffset)) - } - /** * Merge an input buffer into the aggregation buffer, where both buffers contain the deserialized * java object. This function is used by aggregating accumulators. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala index 1aabeb9791c4..d528e9114baa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala @@ -156,15 +156,6 @@ class AggregatingAccumulator private( case agg: AggregatingAccumulator => val buffer = getOrCreateBuffer() val otherBuffer = agg.buffer - // If AggregatingAccumulator runs on driver, - // we should deserialize all TypedImperativeAggregate. - if (isAtDriverSide) { - var i = 0 - while (i < typedImperatives.length) { - typedImperatives(i).deserializeAggregateBufferInPlace(otherBuffer) - i += 1 - } - } mergeProjection.target(buffer)(joinedRow.withRight(otherBuffer)) var i = 0 while (i < imperatives.length) { @@ -172,9 +163,18 @@ class AggregatingAccumulator private( i += 1 } i = 0 - while (i < typedImperatives.length) { - typedImperatives(i).mergeBuffersObjects(buffer, otherBuffer) - i += 1 + if (isAtDriverSide) { + while (i < typedImperatives.length) { + // The input buffer stores serialized data + typedImperatives(i).merge(buffer, otherBuffer) + i += 1 + } + } else { + while (i < typedImperatives.length) { + // The input buffer stores deserialized object + typedImperatives(i).mergeBuffersObjects(buffer, otherBuffer) + i += 1 + } } case _ => throw QueryExecutionErrors.cannotMergeClassWithOtherClassError( @@ -198,13 +198,12 @@ class AggregatingAccumulator private( } override def withBufferSerialized(): AggregatingAccumulator = { - if (!isAtDriverSide) { - var i = 0 - // AggregatingAccumulator runs on executor, we should serialize all TypedImperativeAggregate. - while (i < typedImperatives.length) { - typedImperatives(i).serializeAggregateBufferInPlace(buffer) - i += 1 - } + assert(!isAtDriverSide) + var i = 0 + // AggregatingAccumulator runs on executor, we should serialize all TypedImperativeAggregate. 
+ while (i < typedImperatives.length) { + typedImperatives(i).serializeAggregateBufferInPlace(buffer) + i += 1 } this } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 6aadbab99f27..94c6a2a02f21 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -754,7 +754,7 @@ class DatasetSuite extends QueryTest assert(err2.getMessage.contains("Name must not be empty")) } - test("SPARK-37203: Fix NotSerializableException when observe with percentile_approx") { + test("SPARK-37203: Fix NotSerializableException when observe with TypedImperativeAggregate") { val namedObservation = Observation("named") val df = spark.range(100) From acddbf34e66e70ad64d04b03d300e87049c9b73e Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Fri, 5 Nov 2021 10:01:27 +0800 Subject: [PATCH 6/6] Update code --- .../org/apache/spark/sql/DatasetSuite.scala | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 94c6a2a02f21..2ca3f202a6e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -755,14 +755,16 @@ class DatasetSuite extends QueryTest } test("SPARK-37203: Fix NotSerializableException when observe with TypedImperativeAggregate") { - val namedObservation = Observation("named") - - val df = spark.range(100) - val observed_df = df.observe( - namedObservation, percentile_approx($"id", lit(0.5), lit(100)).as("percentile_approx_val")) + def observe[T](df: Dataset[T], expected: Map[String, _]): Unit = { + val namedObservation = Observation("named") + val observed_df = df.observe( + namedObservation, percentile_approx($"id", lit(0.5), lit(100)).as("percentile_approx_val")) + observed_df.collect() + assert(namedObservation.get === expected) + } - observed_df.collect() - assert(namedObservation.get === Map("percentile_approx_val" -> 49)) + observe(spark.range(100), Map("percentile_approx_val" -> 49)) + observe(spark.range(0), Map("percentile_approx_val" -> null)) } test("sample with replacement") {
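
Note for readers of this series: the scenario the patches fix can be reproduced outside the test suite with a minimal, self-contained sketch like the one below. This is illustrative only, not part of the change; it assumes a Spark build that includes these patches (3.3.0-SNAPSHOT at the time), and the object and app names are made up for the example. Without the fix, collecting the observed Dataset fails with java.io.NotSerializableException because the PercentileDigest buffer of percentile_approx is sent back to the driver without being serialized; with the fix, the observed metric is delivered through the Observation handle.

  import org.apache.spark.sql.{Observation, SparkSession}
  import org.apache.spark.sql.functions.{lit, percentile_approx}

  // Illustrative standalone reproducer; names are hypothetical, not part of the patch.
  object ObservePercentileApproxExample {
    def main(args: Array[String]): Unit = {
      val spark = SparkSession.builder()
        .master("local[2]")
        .appName("observe-percentile-approx")
        .getOrCreate()
      import spark.implicits._

      val observation = Observation("named")

      // Attach an observed metric that uses a TypedImperativeAggregate
      // (percentile_approx) and force execution with collect().
      val observed = spark.range(100)
        .observe(
          observation,
          percentile_approx($"id", lit(0.5), lit(100)).as("percentile_approx_val"))
      observed.collect()

      // With the fix applied this prints Map(percentile_approx_val -> 49);
      // without it the job fails with java.io.NotSerializableException.
      println(observation.get)

      spark.stop()
    }
  }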