Commit 9c16fe6

Fix a couple tests and move getAutoReset to KryoSerializerInstance
1 parent 6c54e06 commit 9c16fe6

6 files changed: +33, -25 lines

core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
Lines changed: 10 additions & 11 deletions

@@ -124,17 +124,6 @@ class KryoSerializer(conf: SparkConf)
   override def newInstance(): SerializerInstance = {
     new KryoSerializerInstance(this)
   }
-
-  /**
-   * Returns true if auto-reset is on. The only reason this would be false is if the user-supplied
-   * register explicitly turns auto-reset off.
-   */
-  def getAutoReset(): Boolean = {
-    val kryo = newKryo()
-    val field = classOf[Kryo].getDeclaredField("autoReset")
-    field.setAccessible(true)
-    field.get(kryo).asInstanceOf[Boolean]
-  }
 }

 private[spark]
@@ -210,6 +199,16 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
   override def deserializeStream(s: InputStream): DeserializationStream = {
     new KryoDeserializationStream(kryo, s)
   }
+
+  /**
+   * Returns true if auto-reset is on. The only reason this would be false is if the user-supplied
+   * registrator explicitly turns auto-reset off.
+   */
+  def getAutoReset(): Boolean = {
+    val field = classOf[Kryo].getDeclaredField("autoReset")
+    field.setAccessible(true)
+    field.get(kryo).asInstanceOf[Boolean]
+  }
 }

 /**
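
The doc comment on the relocated getAutoReset says auto-reset can only be off when a user-supplied registrator disabled it. For context, a minimal sketch of such a registrator, assuming the standard KryoRegistrator hook (the class name is illustrative and not part of this commit):

import com.esotericsoftware.kryo.Kryo
import org.apache.spark.serializer.KryoRegistrator

// Hypothetical registrator: calling setAutoReset(false) here is the one case
// in which KryoSerializerInstance.getAutoReset would return false.
class DisableAutoResetRegistrator extends KryoRegistrator {
  override def registerClasses(kryo: Kryo): Unit = {
    kryo.setAutoReset(false)
  }
}

It would be wired in through the spark.kryo.registrator setting, the same way RegistratorWithoutAutoReset is used in the KryoSerializerSuite change further down. Note also that the new location avoids building a throwaway Kryo object just to read the flag: the removed version called newKryo(), while the instance version reflects on the kryo object it already owns.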

core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
Lines changed: 5 additions & 4 deletions

@@ -26,7 +26,7 @@ import scala.collection.mutable
 import com.google.common.io.ByteStreams

 import org.apache.spark._
-import org.apache.spark.serializer.{DeserializationStream, KryoSerializer, Serializer}
+import org.apache.spark.serializer._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.storage.{BlockObjectWriter, BlockId}

@@ -129,14 +129,15 @@ private[spark] class ExternalSorter[K, V, C](
   private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB
   private val useSerializedPairBuffer =
     !ordering.isDefined && conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) &&
-      ser.isInstanceOf[KryoSerializer] && ser.asInstanceOf[KryoSerializer].getAutoReset
+      ser.isInstanceOf[KryoSerializer] &&
+      serInstance.asInstanceOf[KryoSerializerInstance].getAutoReset

   // Data structures to store in-memory objects before we spill. Depending on whether we have an
   // Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we
   // store them in an array buffer.
   private var map = new PartitionedAppendOnlyMap[K, C]
   private var buffer = if (useSerializedPairBuffer) {
-    new PartitionedSerializedPairBuffer[K, C](metaInitialRecords, kvChunkSize)
+    new PartitionedSerializedPairBuffer[K, C](metaInitialRecords, kvChunkSize, serInstance)
   } else {
     new PartitionedPairBuffer[K, C]
   }
@@ -237,7 +238,7 @@ private[spark] class ExternalSorter[K, V, C](
     } else {
       if (maybeSpill(buffer, buffer.estimateSize())) {
         buffer = if (useSerializedPairBuffer) {
-          new PartitionedSerializedPairBuffer[K, C](metaInitialRecords, kvChunkSize)
+          new PartitionedSerializedPairBuffer[K, C](metaInitialRecords, kvChunkSize, serInstance)
         } else {
           new PartitionedPairBuffer[K, C]
         }
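
For context on the serInstance reference that the updated condition and constructor calls use: it is a SerializerInstance that, as a sketch, would come from the same serializer, so checking auto-reset on it reflects the registrator actually in use. The wiring below is illustrative, not part of this diff:

// Illustrative wiring, assuming a Kryo-backed configuration:
val ser: Serializer = new KryoSerializer(conf)
val serInstance: SerializerInstance = ser.newInstance()
val kryoWithAutoReset =
  ser.isInstanceOf[KryoSerializer] &&
    serInstance.asInstanceOf[KryoSerializerInstance].getAutoReset
// Only when this holds (and no ordering is defined) does the sorter buffer
// map-output records in serialized form.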

core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
Lines changed: 1 addition & 1 deletion

@@ -48,7 +48,7 @@ import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._
 private[spark] class PartitionedSerializedPairBuffer[K, V](
     metaInitialRecords: Int,
     kvBlockSize: Int,
-    serializerInstance: SerializerInstance = SparkEnv.get.serializer.newInstance)
+    serializerInstance: SerializerInstance)
   extends WritablePartitionedPairCollection[K, V] with SizeTracker {

   if (serializerInstance.isInstanceOf[JavaSerializerInstance]) {
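
With the SparkEnv-backed default removed, every caller now supplies the SerializerInstance explicitly rather than pulling one from SparkEnv.get. A minimal call-site sketch (values are illustrative; only the parameter names come from the constructor above):

// Hypothetical call site: a Kryo-backed instance, since the guard above
// rejects a JavaSerializerInstance.
val serInstance = new KryoSerializer(new SparkConf).newInstance()
val pairBuffer = new PartitionedSerializedPairBuffer[Int, String](
  metaInitialRecords = 256,  // initial record capacity of the metadata buffer
  kvBlockSize = 1 << 22,     // chunk size for the serialized key/value bytes (4 MB here)
  serializerInstance = serInstance)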

core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
Lines changed: 2 additions & 2 deletions

@@ -282,11 +282,11 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext {
   }

   test("getAutoReset") {
-    val ser = new KryoSerializer(new SparkConf)
+    val ser = new KryoSerializer(new SparkConf).newInstance().asInstanceOf[KryoSerializerInstance]
     assert(ser.getAutoReset)
     val conf = new SparkConf().set("spark.kryo.registrator",
       classOf[RegistratorWithoutAutoReset].getName)
-    val ser2 = new KryoSerializer(conf)
+    val ser2 = new KryoSerializer(conf).newInstance().asInstanceOf[KryoSerializerInstance]
     assert(!ser2.getAutoReset)
   }
 }

core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala
Lines changed: 10 additions & 2 deletions

@@ -52,8 +52,12 @@ class ChainedBufferSuite extends FunSuite {
   }

   test("write and read at middle") {
-    // write from start of source array
     val buffer = new ChainedBuffer(8)
+
+    // fill to a middle point
+    verifyWriteAndRead(buffer, 0, 0, 0, 3)
+
+    // write from start of source array
     verifyWriteAndRead(buffer, 3, 0, 0, 4)
     buffer.capacity should be (8)

@@ -79,8 +83,12 @@
   }

   test("write and read at later buffer") {
-    // write from start of source array
     val buffer = new ChainedBuffer(8)
+
+    // fill to a middle point
+    verifyWriteAndRead(buffer, 0, 0, 0, 11)
+
+    // write from start of source array
     verifyWriteAndRead(buffer, 11, 0, 0, 4)
     buffer.capacity should be (16)

core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
Lines changed: 5 additions & 5 deletions

@@ -52,7 +52,7 @@ class PartitionedSerializedPairBufferSuite extends FunSuite {
     val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
     val struct = SomeStruct("something", 5)
     buffer.insert(4, 10, struct)
-    val elements = buffer.partitionedDestructiveSortedIterator(null).toArray
+    val elements = buffer.partitionedDestructiveSortedIterator(None).toArray
     elements.size should be (1)
     elements.head should be (((4, 10), struct))
   }
@@ -67,7 +67,7 @@
     val struct3 = SomeStruct("something3", 10)
     buffer.insert(5, 3, struct3)

-    val elements = buffer.partitionedDestructiveSortedIterator(null).toArray
+    val elements = buffer.partitionedDestructiveSortedIterator(None).toArray
     elements.size should be (3)
     elements(0) should be (((4, 2), struct2))
     elements(1) should be (((5, 3), struct3))
@@ -79,7 +79,7 @@
     val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
     val struct = SomeStruct("something", 5)
     buffer.insert(4, 10, struct)
-    val it = buffer.destructiveSortedWritablePartitionedIterator(null)
+    val it = buffer.destructiveSortedWritablePartitionedIterator(None)
     val writer = new SimpleBlockObjectWriter
     assert(it.hasNext)
     it.nextPartition should be (4)
@@ -101,7 +101,7 @@
     val struct3 = SomeStruct("something3", 10)
     buffer.insert(5, 3, struct3)

-    val it = buffer.destructiveSortedWritablePartitionedIterator(null)
+    val it = buffer.destructiveSortedWritablePartitionedIterator(None)
     val writer = new SimpleBlockObjectWriter
     assert(it.hasNext)
     it.nextPartition should be (4)
@@ -142,7 +142,7 @@ class SimpleBlockObjectWriter extends BlockObjectWriter(null) {
   override def isOpen: Boolean = true
   override def commitAndClose(): Unit = { }
   override def revertPartialWritesAndClose(): Unit = { }
-  override def fileSegment(): FileSegment = { null }
+  override def fileSegment(): FileSegment = null
   override def write(key: Any, value: Any): Unit = { }
   override def recordWritten(): Unit = { }
   override def write(b: Int): Unit = { }
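
The switch from null to None in these calls suggests the iterator methods take an Option of a key comparator: None keeps the partition-id ordering the assertions rely on, while supplying a comparator would additionally order keys within each partition. A hedged sketch, reusing SomeStruct and serializerInstance from the suite above:

import java.util.Comparator

val buf = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
buf.insert(4, 10, SomeStruct("something", 5))

val keyOrdering: Comparator[Int] = new Comparator[Int] {
  override def compare(a: Int, b: Int): Int = java.lang.Integer.compare(a, b)
}

// Partition-only ordering, as the updated tests do:
val byPartition = buf.partitionedDestructiveSortedIterator(None)
// Supplying a comparator instead (the iterator is destructive, so only one of
// these calls can be made on a given buffer):
// buf.partitionedDestructiveSortedIterator(Some(keyOrdering))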
