Skip to content

Commit 3a760e7

Browse files
yifeihbulldozer-bot[bot]
authored andcommitted
[SPARK-25299] fix reader benchmarks (apache#544)
Fix the stubbing of the reader benchmark tests
1 parent b35d238 commit 3a760e7

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,12 @@ import org.mockito.{Mock, MockitoAnnotations}
2424
import org.mockito.Answers.RETURNS_SMART_NULLS
2525
import org.mockito.ArgumentMatchers.any
2626
import org.mockito.Mockito.when
27+
import org.mockito.invocation.InvocationOnMock
28+
import org.mockito.stubbing.Answer
2729
import scala.util.Random
2830

2931
import org.apache.spark.{Aggregator, MapOutputTracker, ShuffleDependency, SparkConf, SparkEnv, TaskContext}
32+
import org.apache.spark.api.shuffle.ShuffleLocation
3033
import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
3134
import org.apache.spark.executor.TaskMetrics
3235
import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
@@ -194,14 +197,17 @@ object BlockStoreShuffleReaderBenchmark extends BenchmarkBase {
194197
}
195198

196199
when(mapOutputTracker.getMapSizesByShuffleLocation(0, 0, 1))
197-
.thenReturn {
198-
val shuffleBlockIdsAndSizes = (0 until NUM_MAPS).map { mapId =>
199-
val shuffleBlockId = ShuffleBlockId(0, mapId, 0)
200-
(shuffleBlockId, dataFileLength)
200+
.thenAnswer(new Answer[Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])]] {
201+
def answer(invocationOnMock: InvocationOnMock):
202+
Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = {
203+
val shuffleBlockIdsAndSizes = (0 until NUM_MAPS).map { mapId =>
204+
val shuffleBlockId = ShuffleBlockId(0, mapId, 0)
205+
(shuffleBlockId, dataFileLength)
206+
}
207+
Seq((Option.apply(DefaultMapShuffleLocations.get(dataBlockId)), shuffleBlockIdsAndSizes))
208+
.toIterator
201209
}
202-
Seq((Option.apply(DefaultMapShuffleLocations.get(dataBlockId)), shuffleBlockIdsAndSizes))
203-
.toIterator
204-
}
210+
})
205211

206212
when(dependency.serializer).thenReturn(serializer)
207213
when(dependency.aggregator).thenReturn(aggregator)

0 commit comments

Comments
 (0)