
Commit fa3c90b

beliefer authored and fishcus committed
[SPARK-37203][SQL] Fix NotSerializableException when observe with TypedImperativeAggregate
### What changes were proposed in this pull request?

Currently,

```
val namedObservation = Observation("named")
val df = spark.range(100)
val observed_df = df.observe(
  namedObservation,
  percentile_approx($"id", lit(0.5), lit(100)).as("percentile_approx_val"))
observed_df.collect()
namedObservation.get
```

throws an exception as follows:

```
15:16:27.994 ERROR org.apache.spark.util.Utils: Exception encountered
java.io.NotSerializableException: org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile$PercentileDigest
	at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1184)
	at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1548)
	at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1509)
	at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1432)
	at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178)
	at java.io.ObjectOutputStream.writeArray(ObjectOutputStream.java:1378)
	at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1174)
	at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1548)
	at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1509)
	at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1432)
	at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178)
	at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1548)
	at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1509)
	at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1432)
	at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178)
	at java.io.ObjectOutputStream.writeObject(ObjectOutputStream.java:348)
	at org.apache.spark.scheduler.DirectTaskResult.$anonfun$writeExternal$2(TaskResult.scala:55)
	at org.apache.spark.scheduler.DirectTaskResult.$anonfun$writeExternal$2$adapted(TaskResult.scala:55)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
	at scala.collection.IterableLike.foreach(IterableLike.scala:74)
	at scala.collection.IterableLike.foreach$(IterableLike.scala:73)
	at scala.collection.AbstractIterable.foreach(Iterable.scala:56)
	at org.apache.spark.scheduler.DirectTaskResult.$anonfun$writeExternal$1(TaskResult.scala:55)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at org.apache.spark.util.Utils$.tryOrIOException(Utils.scala:1434)
	at org.apache.spark.scheduler.DirectTaskResult.writeExternal(TaskResult.scala:51)
	at java.io.ObjectOutputStream.writeExternalData(ObjectOutputStream.java:1459)
	at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1430)
	at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178)
	at java.io.ObjectOutputStream.writeObject(ObjectOutputStream.java:348)
	at org.apache.spark.serializer.JavaSerializationStream.writeObject(JavaSerializer.scala:44)
	at org.apache.spark.serializer.JavaSerializerInstance.serialize(JavaSerializer.scala:101)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:616)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)
```

This PR fixes the issue. After the change, `assert(namedObservation.get === Map("percentile_approx_val" -> 49))` passes and the `java.io.NotSerializableException` no longer occurs.

### Why are the changes needed?

Fix `NotSerializableException` when calling `observe` with a `TypedImperativeAggregate`.

### Does this PR introduce any user-facing change?

No. This PR changes the implementation of `AggregatingAccumulator`, which now uses the serialize and deserialize support of `TypedImperativeAggregate`.

### How was this patch tested?

New tests.

Closes apache#34474 from beliefer/SPARK-37203.

Authored-by: Jiaan Geng <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit 3f3201a)
Signed-off-by: Wenchen Fan <[email protected]>
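For context, here is a minimal, self-contained sketch of the failure mode and of the fix, using hypothetical `Digest` and `Acc` classes rather than Spark's actual types: the task result is Java-serialized on its way back to the driver, and a live `TypedImperativeAggregate` buffer object (such as `PercentileDigest`) need not be `java.io.Serializable`, so it has to be replaced by its byte form first.

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}
import java.nio.ByteBuffer
import scala.util.Try

// Stand-in for PercentileDigest: mutable state that is deliberately NOT Serializable.
final class Digest(var count: Long)

// Stand-in for an accumulator that travels executor -> driver via Java serialization.
class Acc(var digest: Digest, var digestBytes: Array[Byte]) extends Serializable {
  // The essence of the fix: swap the live object for its serialized byte form
  // before ObjectOutputStream walks the object graph.
  def withBufferSerialized(): Acc = {
    digestBytes = ByteBuffer.allocate(8).putLong(digest.count).array()
    digest = null
    this
  }
}

object NotSerializableDemo extends App {
  private def javaSerialize(o: AnyRef): Try[Unit] =
    Try(new ObjectOutputStream(new ByteArrayOutputStream()).writeObject(o))

  println(javaSerialize(new Acc(new Digest(42L), null)))
  // Failure(java.io.NotSerializableException: Digest)
  println(javaSerialize(new Acc(new Digest(42L), null).withBufferSerialized()))
  // Success(())
}
```

The actual change performs this swap through `AccumulatorV2.writeReplace()` and `AggregatingAccumulator.withBufferSerialized()`, as the diffs below show.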
1 parent 640d88b · commit fa3c90b

4 files changed: +33 additions, −8 deletions

core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala

Lines changed: 5 additions & 1 deletion
```diff
@@ -155,6 +155,10 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
    */
   def value: OUT
 
+  // Serialize the buffer of this accumulator before sending back this accumulator to the driver.
+  // By default this method does nothing.
+  protected def withBufferSerialized(): AccumulatorV2[IN, OUT] = this
+
   // Called by Java when serializing an object
   final protected def writeReplace(): Any = {
     if (atDriverSide) {
@@ -177,7 +181,7 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
       }
       copyAcc
     } else {
-      this
+      withBufferSerialized()
     }
   }
 
```
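`writeReplace()` above is a standard hook of Java serialization: if an object defines it, `ObjectOutputStream` calls it and writes the returned object in place of the original. A tiny sketch of the mechanism, with a hypothetical `Box` class:

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}

// Hypothetical class showing the hook: ObjectOutputStream invokes writeReplace()
// and serializes whatever it returns instead of the original object.
class Box(val payload: String) extends Serializable {
  protected def withBufferSerialized(): Box = {
    println("withBufferSerialized() called during serialization")
    this
  }
  final protected def writeReplace(): Any = withBufferSerialized()
}

object WriteReplaceDemo extends App {
  new ObjectOutputStream(new ByteArrayOutputStream()).writeObject(new Box("x"))
  // prints: withBufferSerialized() called during serialization
}
```

Hooking `withBufferSerialized()` into this path means the buffer is prepared exactly when the accumulator is serialized into the task result, without touching the task-result code itself.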

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -274,7 +274,7 @@ object ApproximatePercentile {
   }
 
   /**
-   * Serializer for class [[PercentileDigest]]
+   * Serializer for class [[PercentileDigest]]
    *
    * This class is thread safe.
    */
```

sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala

Lines changed: 23 additions & 3 deletions
```diff
@@ -162,9 +162,18 @@ class AggregatingAccumulator private(
           i += 1
         }
         i = 0
-        while (i < typedImperatives.length) {
-          typedImperatives(i).mergeBuffersObjects(buffer, otherBuffer)
-          i += 1
+        if (isAtDriverSide) {
+          while (i < typedImperatives.length) {
+            // The input buffer stores serialized data
+            typedImperatives(i).merge(buffer, otherBuffer)
+            i += 1
+          }
+        } else {
+          while (i < typedImperatives.length) {
+            // The input buffer stores deserialized object
+            typedImperatives(i).mergeBuffersObjects(buffer, otherBuffer)
+            i += 1
+          }
         }
       case _ =>
         throw new UnsupportedOperationException(
@@ -187,6 +196,17 @@ class AggregatingAccumulator private(
     resultProjection(input)
   }
 
+  override def withBufferSerialized(): AggregatingAccumulator = {
+    assert(!isAtDriverSide)
+    var i = 0
+    // AggregatingAccumulator runs on executor, we should serialize all TypedImperativeAggregate.
+    while (i < typedImperatives.length) {
+      typedImperatives(i).serializeAggregateBufferInPlace(buffer)
+      i += 1
+    }
+    this
+  }
+
   /**
    * Get the output schema of the aggregating accumulator.
    */
```
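One way to picture the asymmetry the new `merge` branches on, again with a hypothetical `Digest` rather than Spark's types: after `withBufferSerialized()` runs on the executor, the buffer slots that reach the driver hold serialized bytes, while executor-local merges still see live objects.

```scala
import java.nio.ByteBuffer

// Hypothetical digest with both an object-level merge and a byte codec.
final class Digest(var count: Long) {
  def merge(other: Digest): Digest = { count += other.count; this }
  def toBytes: Array[Byte] = ByteBuffer.allocate(8).putLong(count).array()
}
object Digest {
  def fromBytes(b: Array[Byte]): Digest = new Digest(ByteBuffer.wrap(b).getLong)
}

// Executor side (the mergeBuffersObjects branch): both slots hold live objects.
def mergeLive(a: Digest, b: Digest): Digest = a.merge(b)

// Driver side (the merge branch): both slots hold serialized bytes, so decode first.
def mergeSerialized(a: Array[Byte], b: Array[Byte]): Array[Byte] =
  Digest.fromBytes(a).merge(Digest.fromBytes(b)).toBytes
```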

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala

Lines changed: 4 additions & 3 deletions
```diff
@@ -417,7 +417,8 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
         min($"value").as("min_val"),
         max($"value").as("max_val"),
         sum($"value").as("sum_val"),
-        count(when($"value" % 2 === 0, 1)).as("num_even"))
+        count(when($"value" % 2 === 0, 1)).as("num_even"),
+        percentile_approx($"value", lit(0.5), lit(100)).as("percentile_approx_val"))
       .observe(
         name = "other_event",
         avg($"value").cast("int").as("avg_val"))
@@ -444,15 +445,15 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
       AddData(inputData, 1, 2),
       AdvanceManualClock(100),
       checkMetrics { metrics =>
-        assert(metrics.get("my_event") === Row(1, 2, 3L, 1L))
+        assert(metrics.get("my_event") === Row(1, 2, 3L, 1L, 1))
         assert(metrics.get("other_event") === Row(1))
       },
 
       // Batch 2
       AddData(inputData, 10, 30, -10, 5),
       AdvanceManualClock(100),
       checkMetrics { metrics =>
-        assert(metrics.get("my_event") === Row(-10, 30, 35L, 3L))
+        assert(metrics.get("my_event") === Row(-10, 30, 35L, 3L, 5))
         assert(metrics.get("other_event") === Row(8))
       },
 
```
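For reference, the two new trailing values follow from the per-batch data: batch 1 observes 1 and 2, whose approximate median (`percentile_approx` at 0.5) is 1, and batch 2 observes 10, 30, -10, 5, whose approximate median is 5 — hence `Row(1, 2, 3L, 1L, 1)` and `Row(-10, 30, 35L, 3L, 5)`.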
