@@ -24,9 +24,12 @@ import org.mockito.{Mock, MockitoAnnotations}
2424import org .mockito .Answers .RETURNS_SMART_NULLS
2525import org .mockito .ArgumentMatchers .any
2626import org .mockito .Mockito .when
27+ import org .mockito .invocation .InvocationOnMock
28+ import org .mockito .stubbing .Answer
2729import scala .util .Random
2830
2931import org .apache .spark .{Aggregator , MapOutputTracker , ShuffleDependency , SparkConf , SparkEnv , TaskContext }
32+ import org .apache .spark .api .shuffle .ShuffleLocation
3033import org .apache .spark .benchmark .{Benchmark , BenchmarkBase }
3134import org .apache .spark .executor .TaskMetrics
3235import org .apache .spark .memory .{TaskMemoryManager , TestMemoryManager }
@@ -194,14 +197,17 @@ object BlockStoreShuffleReaderBenchmark extends BenchmarkBase {
194197 }
195198
196199 when(mapOutputTracker.getMapSizesByShuffleLocation(0 , 0 , 1 ))
197- .thenReturn {
198- val shuffleBlockIdsAndSizes = (0 until NUM_MAPS ).map { mapId =>
199- val shuffleBlockId = ShuffleBlockId (0 , mapId, 0 )
200- (shuffleBlockId, dataFileLength)
200+ .thenAnswer(new Answer [Iterator [(Option [ShuffleLocation ], Seq [(BlockId , Long )])]] {
201+ def answer (invocationOnMock : InvocationOnMock ):
202+ Iterator [(Option [ShuffleLocation ], Seq [(BlockId , Long )])] = {
203+ val shuffleBlockIdsAndSizes = (0 until NUM_MAPS ).map { mapId =>
204+ val shuffleBlockId = ShuffleBlockId (0 , mapId, 0 )
205+ (shuffleBlockId, dataFileLength)
206+ }
207+ Seq ((Option .apply(DefaultMapShuffleLocations .get(dataBlockId)), shuffleBlockIdsAndSizes))
208+ .toIterator
201209 }
202- Seq ((Option .apply(DefaultMapShuffleLocations .get(dataBlockId)), shuffleBlockIdsAndSizes))
203- .toIterator
204- }
210+ })
205211
206212 when(dependency.serializer).thenReturn(serializer)
207213 when(dependency.aggregator).thenReturn(aggregator)
0 commit comments