@@ -19,17 +19,20 @@ package org.apache.spark.sql.execution.python
 
 import scala.collection.mutable.ArrayBuffer
 
+import org.mockito.Mockito.when
 import org.scalatest.concurrent.Eventually
+import org.scalatest.mockito.MockitoSugar
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark._
 import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
+import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection}
 import org.apache.spark.sql.execution.python.PythonForeachWriter.UnsafeRowBuffer
 import org.apache.spark.sql.types.{DataType, IntegerType}
 import org.apache.spark.util.Utils
 
-class PythonForeachWriterSuite extends SparkFunSuite with Eventually {
+class PythonForeachWriterSuite extends SparkFunSuite with Eventually with MockitoSugar {
 
   testWithBuffer("UnsafeRowBuffer: iterator blocks when no data is available") { b =>
     b.assertIteratorBlocked()
@@ -75,15 +78,20 @@ class PythonForeachWriterSuite extends SparkFunSuite with Eventually {
       tester = new BufferTester(memBytes, sleepPerRowReadMs)
       f(tester)
     } finally {
-      if (tester == null) tester.close()
+      if (tester != null) tester.close()
     }
   }
 }
 
 
 class BufferTester(memBytes: Long, sleepPerRowReadMs: Int) {
   private val buffer = {
-    val mem = new TestMemoryManager(new SparkConf())
+    val mockEnv = mock[SparkEnv]
+    val conf = new SparkConf()
+    val serializerManager = new SerializerManager(new JavaSerializer(conf), conf, None)
+    when(mockEnv.serializerManager).thenReturn(serializerManager)
+    SparkEnv.set(mockEnv)
+    val mem = new TestMemoryManager(conf)
     mem.limit(memBytes)
     val taskM = new TaskMemoryManager(mem, 0)
     new UnsafeRowBuffer(taskM, Utils.createTempDir(), 1)
0 commit comments