Commit ee6b9a4

Author: Davies Liu (committed)
Message: fix build
Parent: d0ada7b

File tree

9 files changed (+93, -106 lines)

core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java

Lines changed: 35 additions & 52 deletions
@@ -25,7 +25,10 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.spark.SparkException;
+import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.util.Utils;
 
 /**
  * Manages the memory allocated by an individual task.
@@ -117,14 +120,6 @@ public TaskMemoryManager(MemoryManager memoryManager, long taskAttemptId) {
     this.consumers = new HashMap<>();
   }
 
-  /**
-   * Acquire N bytes of memory for execution, evicting cached blocks if necessary.
-   * @return number of bytes successfully granted (<= N).
-   */
-  public long acquireExecutionMemory(long size) {
-    return memoryManager.acquireExecutionMemory(size, taskAttemptId);
-  }
-
   /**
    * Acquire N bytes of memory for a consumer. If there is no enough memory, it will call
    * spill() of consumers to release more memory.
@@ -133,26 +128,26 @@ public long acquireExecutionMemory(long size) {
    */
   public long acquireExecutionMemory(long size, MemoryConsumer consumer) throws IOException {
     synchronized (this) {
-      long got = acquireExecutionMemory(size);
+      long got = memoryManager.acquireExecutionMemory(size, taskAttemptId);
 
+      // call spill() on itself to release some memory
       if (got < size && consumer != null) {
-        // call spill() on itself to release some memory
         consumer.spill(size - got);
-        got += acquireExecutionMemory(size - got);
-
-        if (got < size) {
-          long needed = size - got;
-          // call spill() on other consumers to release memory
-          for (MemoryConsumer c: consumers.keySet()) {
-            if (c != consumer) {
-              needed -= c.spill(size - got);
-              if (needed < 0) {
-                break;
-              }
+        got += memoryManager.acquireExecutionMemory(size - got, taskAttemptId);
+      }
+
+      if (got < size) {
+        long needed = size - got;
+        // call spill() on other consumers to release memory
+        for (MemoryConsumer c: consumers.keySet()) {
+          if (c != null && c != consumer) {
+            needed -= c.spill(size - got);
+            if (needed < 0) {
+              break;
            }
          }
-          got += acquireExecutionMemory(size - got);
        }
+        got += memoryManager.acquireExecutionMemory(size - got, taskAttemptId);
      }
 
      long old = 0L;
@@ -165,30 +160,35 @@ public long acquireExecutionMemory(long size, MemoryConsumer consumer) throws IOException {
     }
   }
 
-  /**
-   * Release N bytes of execution memory.
-   */
-  public void releaseExecutionMemory(long size) {
-    memoryManager.releaseExecutionMemory(size, taskAttemptId);
-  }
-
   /**
    * Release N bytes of execution memory for a MemoryConsumer.
    */
   public void releaseExecutionMemory(long size, MemoryConsumer consumer) {
     synchronized (this) {
-      if (consumer != null && consumers.containsKey(consumer)) {
+      if (consumers.containsKey(consumer)) {
         long old = consumers.get(consumer);
         if (old > size) {
           consumers.put(consumer, old - size);
         } else {
           if (old < size) {
-            // TODO
+            if (Utils.isTesting()) {
+              Platform.throwException(
+                new SparkException("Release more memory " + size + "than acquired " + old + " for "
+                  + consumer));
+            } else {
+              logger.warn("Release more memory " + size + " than acquired " + old + "for "
+                + consumer);
+            }
           }
           consumers.remove(consumer);
         }
       } else {
-        // TODO
+        if (Utils.isTesting()) {
+          Platform.throwException(
+            new SparkException("Release memory " + size + " for not existed " + consumer));
+        } else {
+          logger.warn("Release memory " + size + " for not existed " + consumer);
+        }
       }
       memoryManager.releaseExecutionMemory(size, taskAttemptId);
     }
@@ -198,16 +198,6 @@ public long pageSizeBytes() {
     return memoryManager.pageSizeBytes();
   }
 
-  /**
-   * Allocate a block of memory that will be tracked in the MemoryManager's page table; this is
-   * intended for allocating large blocks of Tungsten memory that will be shared between operators.
-   *
-   * Returns `null` if there was not enough memory to allocate the page.
-   */
-  public MemoryBlock allocatePage(long size) throws IOException {
-    return allocatePage(size, null);
-  }
-
   /**
    * Allocate a block of memory that will be tracked in the MemoryManager's page table; this is
    * intended for allocating large blocks of Tungsten memory that will be shared between operators.
@@ -244,14 +234,7 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) throws IOException {
   }
 
   /**
-   * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}.
-   */
-  public void freePage(MemoryBlock page) {
-    freePage(page, null);
-  }
-
-  /**
-   * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}.
+   * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}.
    */
   public void freePage(MemoryBlock page, MemoryConsumer consumer) {
     assert (page.pageNumber != -1) :
@@ -273,7 +256,7 @@ public void freePage(MemoryBlock page, MemoryConsumer consumer) {
    * Given a memory page and offset within that page, encode this address into a 64-bit long.
    * This address will remain valid as long as the corresponding page has not been freed.
    *
-   * @param page a data page allocated by {@link TaskMemoryManager#allocatePage(long)}/
+   * @param page a data page allocated by {@link TaskMemoryManager#allocatePage}/
   * @param offsetInPage an offset in this page which incorporates the base offset. In other words,
   *                     this should be the value that you would pass as the base offset into an
   *                     UNSAFE call (e.g. page.baseOffset() + something).
@@ -349,7 +332,7 @@ public long cleanUpAllAllocatedMemory() {
     for (MemoryBlock page : pageTable) {
       if (page != null) {
         freedBytes += page.size();
-        freePage(page);
+        freePage(page, null);
       }
     }
 
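Note: taken together, these changes funnel every acquire/allocate through the MemoryConsumer-aware overloads; callers with no consumer now pass null explicitly, as the test updates below show. A minimal sketch of the resulting call pattern, reusing the GrantEverythingMemoryManager test double that this commit's suites rely on (sketch only, not part of the commit):

    import java.io.IOException;

    import org.apache.spark.SparkConf;
    import org.apache.spark.memory.GrantEverythingMemoryManager;
    import org.apache.spark.memory.TaskMemoryManager;
    import org.apache.spark.unsafe.memory.MemoryBlock;

    public class TaskMemoryManagerSketch {
      public static void main(String[] args) throws IOException {
        final TaskMemoryManager manager = new TaskMemoryManager(
          new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);

        // Passing a real MemoryConsumer lets the manager trigger its spill()
        // (and other consumers') under pressure; consumer-less callers pass null.
        long got = manager.acquireExecutionMemory(4096, null);

        // allocatePage/freePage now always take the owning consumer too.
        MemoryBlock page = manager.allocatePage(4096, null);
        manager.freePage(page, null);

        manager.releaseExecutionMemory(got, null);

        // Anything still registered is reclaimed here; a non-zero return
        // signals a leak (FailureSuite below turns that into a task failure).
        long leakedBytes = manager.cleanUpAllAllocatedMemory();
        System.out.println("leaked: " + leakedBytes);
      }
    }
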
core/src/main/scala/org/apache/spark/util/collection/Spillable.scala

Lines changed: 2 additions & 2 deletions
@@ -78,7 +78,7 @@ private[spark] trait Spillable[C] extends Logging {
     if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
       // Claim up to double our current memory from the shuffle memory pool
       val amountToRequest = 2 * currentMemory - myMemoryThreshold
-      val granted = taskMemoryManager.acquireExecutionMemory(amountToRequest)
+      val granted = taskMemoryManager.acquireExecutionMemory(amountToRequest, null)
       myMemoryThreshold += granted
       // If we were granted too little memory to grow further (either tryToAcquire returned 0,
       // or we already had more memory than myMemoryThreshold), spill the current collection
@@ -107,7 +107,7 @@
    */
   def releaseMemory(): Unit = {
     // The amount we requested does not include the initial memory tracking threshold
-    taskMemoryManager.releaseExecutionMemory(myMemoryThreshold - initialMemoryThreshold)
+    taskMemoryManager.releaseExecutionMemory(myMemoryThreshold - initialMemoryThreshold, null)
     myMemoryThreshold = initialMemoryThreshold
   }
 
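Note: the two-argument calls above leave Spillable's growth policy untouched: when a spill check fires, it requests enough to double the collection's footprint relative to what is already tracked. A worked sketch of that arithmetic with hypothetical numbers (plain Java for illustration; the names mirror the Scala fields above, and this is not part of the commit):

    public class SpillableGrowthSketch {
      public static void main(String[] args) {
        long myMemoryThreshold = 5L << 20;  // 5 MB already granted/tracked
        long currentMemory = 8L << 20;      // collection now estimated at 8 MB

        // Claim up to double our current memory from the shuffle memory pool.
        long amountToRequest = 2 * currentMemory - myMemoryThreshold;  // 11 MB

        long granted = amountToRequest;     // the manager may grant less
        myMemoryThreshold += granted;       // 16 MB == 2 * currentMemory

        // Per the comment in the diff: if the grant was too small to grow
        // past the collection's current size, the collection spills.
        boolean shouldSpill = currentMemory >= myMemoryThreshold;
        System.out.println(amountToRequest + " requested, spill=" + shouldSpill);
      }
    }
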
core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java

Lines changed: 8 additions & 6 deletions
@@ -17,6 +17,8 @@
 
 package org.apache.spark.memory;
 
+import java.io.IOException;
+
 import org.junit.Assert;
 import org.junit.Test;
 
@@ -26,18 +28,18 @@
 public class TaskMemoryManagerSuite {
 
   @Test
-  public void leakedPageMemoryIsDetected() {
+  public void leakedPageMemoryIsDetected() throws IOException {
     final TaskMemoryManager manager = new TaskMemoryManager(
       new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
-    manager.allocatePage(4096); // leak memory
+    manager.allocatePage(4096, null); // leak memory
     Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory());
   }
 
   @Test
-  public void encodePageNumberAndOffsetOffHeap() {
+  public void encodePageNumberAndOffsetOffHeap() throws IOException {
     final TaskMemoryManager manager = new TaskMemoryManager(
       new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "true")), 0);
-    final MemoryBlock dataPage = manager.allocatePage(256);
+    final MemoryBlock dataPage = manager.allocatePage(256, null);
     // In off-heap mode, an offset is an absolute address that may require more than 51 bits to
     // encode. This test exercises that corner-case:
     final long offset = ((1L << TaskMemoryManager.OFFSET_BITS) + 10);
@@ -47,10 +49,10 @@ public void encodePageNumberAndOffsetOffHeap() {
   }
 
   @Test
-  public void encodePageNumberAndOffsetOnHeap() {
+  public void encodePageNumberAndOffsetOnHeap() throws IOException {
     final TaskMemoryManager manager = new TaskMemoryManager(
       new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
-    final MemoryBlock dataPage = manager.allocatePage(256);
+    final MemoryBlock dataPage = manager.allocatePage(256, null);
     final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64);
     Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress));
     Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress));
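Note: the two encode tests exercise TaskMemoryManager's packed 64-bit page addresses. As a standalone illustration, here is a sketch of such a packing scheme, assuming the 13-bit page number / 51-bit offset split implied by OFFSET_BITS in the test above (the constant's value is an assumption, and this is not Spark's actual implementation):

    public class PackedAddressSketch {
      // ASSUMED split: high 13 bits = page number, low 51 bits = in-page offset.
      static final int OFFSET_BITS = 51;
      static final long OFFSET_MASK = (1L << OFFSET_BITS) - 1;

      static long encode(int pageNumber, long offsetInPage) {
        return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & OFFSET_MASK);
      }

      static int pageNumber(long address) {
        return (int) (address >>> OFFSET_BITS);
      }

      static long offsetInPage(long address) {
        return address & OFFSET_MASK;
      }

      public static void main(String[] args) {
        // Off-heap offsets can exceed 51 bits, which is the corner case the
        // off-heap test probes with (1L << OFFSET_BITS) + 10.
        long address = encode(5, 64);
        System.out.println(pageNumber(address) + " / " + offsetInPage(address)); // 5 / 64
      }
    }
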

core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java

Lines changed: 4 additions & 4 deletions
@@ -38,8 +38,8 @@ public void heap() throws IOException {
     final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "false");
     final TaskMemoryManager memoryManager =
       new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
-    final MemoryBlock page0 = memoryManager.allocatePage(128);
-    final MemoryBlock page1 = memoryManager.allocatePage(128);
+    final MemoryBlock page0 = memoryManager.allocatePage(128, null);
+    final MemoryBlock page1 = memoryManager.allocatePage(128, null);
     final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
       page1.getBaseOffset() + 42);
     PackedRecordPointer packedPointer = new PackedRecordPointer();
@@ -57,8 +57,8 @@ public void offHeap() throws IOException {
     final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "true");
     final TaskMemoryManager memoryManager =
       new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
-    final MemoryBlock page0 = memoryManager.allocatePage(128);
-    final MemoryBlock page1 = memoryManager.allocatePage(128);
+    final MemoryBlock page0 = memoryManager.allocatePage(128, null);
+    final MemoryBlock page1 = memoryManager.allocatePage(128, null);
     final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
       page1.getBaseOffset() + 42);
     PackedRecordPointer packedPointer = new PackedRecordPointer();

core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ public void testBasicSorting() throws Exception {
     final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "false");
     final TaskMemoryManager memoryManager =
       new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
-    final MemoryBlock dataPage = memoryManager.allocatePage(2048);
+    final MemoryBlock dataPage = memoryManager.allocatePage(2048, null);
     final Object baseObject = dataPage.getBaseObject();
     final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4);
     final HashPartitioner hashPartitioner = new HashPartitioner(4);

core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java

Lines changed: 2 additions & 2 deletions
@@ -402,7 +402,7 @@ public void writeEnoughDataToTriggerSpill() throws Exception {
       .doCallRealMethod() // allocate initial data page
       .doReturn(0L) // deny request to allocate new page
       .doCallRealMethod() // grant new sort buffer and data page
-      .when(taskMemoryManager).acquireExecutionMemory(anyLong());
+      .when(taskMemoryManager).acquireExecutionMemory(anyLong(), null);
     final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
     final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
     final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 128];
@@ -430,7 +430,7 @@ public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
       .doCallRealMethod() // allocate initial data page
       .doReturn(0L) // deny request to allocate new page
       .doCallRealMethod() // grant new sort buffer and data page
-      .when(taskMemoryManager).acquireExecutionMemory(anyLong());
+      .when(taskMemoryManager).acquireExecutionMemory(anyLong(), null);
     final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
     final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
     for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) {
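Note (editorial caveat on these stubs): once any argument uses a Mockito matcher (anyLong() here), Mockito requires every argument to be a matcher, so pairing anyLong() with a raw null normally fails at runtime with InvalidUseOfMatchersException. A sketch of the matcher-only form, assuming Mockito 1.x's isNull(Class) and that MemoryConsumer lives in org.apache.spark.memory alongside TaskMemoryManager, as the diff implies (this sketch is not part of the commit):

    import static org.mockito.Matchers.anyLong;
    import static org.mockito.Matchers.isNull;
    import static org.mockito.Mockito.doCallRealMethod;

    import java.io.IOException;

    import org.apache.spark.memory.MemoryConsumer;
    import org.apache.spark.memory.TaskMemoryManager;

    class AcquireExecutionMemoryStubSketch {
      // Registers the same three-step stub as above, but with a matcher
      // for the consumer argument as well.
      static void stub(TaskMemoryManager taskMemoryManager) throws IOException {
        doCallRealMethod()        // allocate initial data page
          .doReturn(0L)           // deny request to allocate new page
          .doCallRealMethod()     // grant new sort buffer and data page
          .when(taskMemoryManager)
          .acquireExecutionMemory(anyLong(), isNull(MemoryConsumer.class));
      }
    }
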

core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java

Lines changed: 9 additions & 7 deletions
@@ -20,17 +20,19 @@
 import java.util.Arrays;
 
 import org.junit.Test;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.*;
-import static org.junit.Assert.*;
-import static org.mockito.Mockito.mock;
 
 import org.apache.spark.HashPartitioner;
 import org.apache.spark.SparkConf;
-import org.apache.spark.unsafe.Platform;
 import org.apache.spark.memory.GrantEverythingMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
+import static org.hamcrest.Matchers.isIn;
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.mock;
 
 public class UnsafeInMemorySorterSuite {
 
@@ -67,7 +69,7 @@ public void testSortingOnlyByIntegerPrefix() throws Exception {
     };
     final TaskMemoryManager memoryManager = new TaskMemoryManager(
       new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
-    final MemoryBlock dataPage = memoryManager.allocatePage(2048);
+    final MemoryBlock dataPage = memoryManager.allocatePage(2048, null);
     final Object baseObject = dataPage.getBaseObject();
     // Write the records into the data page:
     long position = dataPage.getBaseOffset();

core/src/test/scala/org/apache/spark/FailureSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -149,7 +149,7 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext {
     // cause is preserved
     val thrownDueToTaskFailure = intercept[SparkException] {
       sc.parallelize(Seq(0)).mapPartitions { iter =>
-        TaskContext.get().taskMemoryManager().allocatePage(128)
+        TaskContext.get().taskMemoryManager().allocatePage(128, null)
         throw new Exception("intentional task failure")
         iter
       }.count()
@@ -159,7 +159,7 @@
     // If the task succeeded but memory was leaked, then the task should fail due to that leak
     val thrownDueToMemoryLeak = intercept[SparkException] {
       sc.parallelize(Seq(0)).mapPartitions { iter =>
-        TaskContext.get().taskMemoryManager().allocatePage(128)
+        TaskContext.get().taskMemoryManager().allocatePage(128, null)
         iter
       }.count()
     }
