Skip to content
Prev Previous commit
Next Next commit
address comments
  • Loading branch information
Davies Liu committed Oct 29, 2015
commit afc8c7c9b0e92f9db6a5f72c14d8484a87311a51
42 changes: 29 additions & 13 deletions core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,28 @@
*/
public abstract class MemoryConsumer {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about naming this class SpillableMemoryConsumer ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it too long?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The length is about the same as TaskMemoryManager - so not too long.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm neutral on the name change. At first I thought that the name MemoryConsumer might not make sense if it was used by places that can't spill, but I suppose that those places could just have spill() return 0. So I'm fine sticking with the current name.


private TaskMemoryManager memoryManager;
private long pageSize;
private final TaskMemoryManager taskMemoryManager;
private final long pageSize;
private long used;

protected MemoryConsumer(TaskMemoryManager memoryManager, long pageSize) {
this.memoryManager = memoryManager;
protected MemoryConsumer(TaskMemoryManager taskMemoryManager, long pageSize) {
this.taskMemoryManager = taskMemoryManager;
if (pageSize == 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this necessary since we have another constructor?

pageSize = taskMemoryManager.pageSizeBytes();
}
this.pageSize = pageSize;
this.used = 0;
}

protected MemoryConsumer(TaskMemoryManager taskMemoryManager) {
this(taskMemoryManager, taskMemoryManager.pageSizeBytes());
}

protected MemoryConsumer(TaskMemoryManager memoryManager) {
this(memoryManager, memoryManager.pageSizeBytes());
/**
* Returns the size of used memory in bytes.
*/
long getUsed() {
return used;
}

/**
Expand Down Expand Up @@ -70,19 +82,21 @@ public void spill() throws IOException {
* If there is not enough memory, throws OutOfMemoryError.
*/
protected void acquireMemory(long size) {
long got = memoryManager.acquireExecutionMemory(size, this);
long got = taskMemoryManager.acquireExecutionMemory(size, this);
if (got < size) {
memoryManager.releaseExecutionMemory(got, this);
memoryManager.showMemoryUsage();
taskMemoryManager.releaseExecutionMemory(got, this);
taskMemoryManager.showMemoryUsage();
throw new OutOfMemoryError("Could not acquire " + size + " bytes of memory, got " + got);
}
used += got;
}

/**
* Release `size` bytes memory.
*/
protected void releaseMemory(long size) {
memoryManager.releaseExecutionMemory(size, this);
taskMemoryManager.releaseExecutionMemory(size, this);
used -= size;
}

/**
Expand All @@ -93,23 +107,25 @@ protected void releaseMemory(long size) {
* @throws OutOfMemoryError
*/
protected MemoryBlock allocatePage(long required) {
MemoryBlock page = memoryManager.allocatePage(Math.max(pageSize, required), this);
MemoryBlock page = taskMemoryManager.allocatePage(Math.max(pageSize, required), this);
if (page == null || page.size() < required) {
long got = 0;
if (page != null) {
got = page.size();
freePage(page);
}
memoryManager.showMemoryUsage();
taskMemoryManager.showMemoryUsage();
throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got);
}
used += page.size();
return page;
}

/**
* Free a memory block.
*/
protected void freePage(MemoryBlock page) {
memoryManager.freePage(page, this);
taskMemoryManager.freePage(page, this);
used -= page.size();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, maybe an invalid concern, but is it safe to call page.size() on a freed page?

}
}
87 changes: 15 additions & 72 deletions core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,12 @@
import java.io.IOException;
import java.util.Arrays;
import java.util.BitSet;
import java.util.HashMap;
import java.util.HashSet;

import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.SparkException;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.util.Utils;

Expand Down Expand Up @@ -109,7 +107,7 @@ public class TaskMemoryManager {
/**
* The size of memory granted to each consumer.
*/
private final HashMap<MemoryConsumer, Long> consumers;
private final HashSet<MemoryConsumer> consumers;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment to explain that this field is guarded by synchronizing on this (or use @GuardedBy("this")).


/**
* Construct a new TaskMemoryManager.
Expand All @@ -118,7 +116,7 @@ public TaskMemoryManager(MemoryManager memoryManager, long taskAttemptId) {
this.inHeap = memoryManager.tungstenMemoryIsAllocatedInHeap();
this.memoryManager = memoryManager;
this.taskAttemptId = taskAttemptId;
this.consumers = new HashMap<>();
this.consumers = new HashSet<>();
}

/**
Expand All @@ -135,12 +133,9 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) {
// try to release memory from other consumers first, then we can reduce the frequency of
// spilling, avoid to have too many spilled files.
if (got < required) {
// consumers could be modified by spill(), so we should have a copy here.
MemoryConsumer[] cs = new MemoryConsumer[consumers.size()];
consumers.keySet().toArray(cs);
// Call spill() on other consumers to release memory
for (MemoryConsumer c: cs) {
if (c != null && c != consumer) {
for (MemoryConsumer c: consumers) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this approach still have the same concern about concurrent modification of consumers while iterating over it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, we never remove it, and it will not add more under this lock.

if (c != null && c != consumer && c.getUsed() > 0) {
try {
long released = c.spill(required - got, consumer);
if (released > 0) {
Expand Down Expand Up @@ -176,15 +171,7 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) {
}
}

// Update the accounting, even consumer is null
if (got > 0) {
long old = 0L;
if (consumers.containsKey(consumer)) {
old = consumers.get(consumer);
}
consumers.put(consumer, got + old);
}

consumers.add(consumer);
logger.debug("Task {} acquire {} for {}", taskAttemptId, Utils.bytesToString(got), consumer);
return got;
}
Expand All @@ -194,67 +181,20 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) {
* Release N bytes of execution memory for a MemoryConsumer.
*/
public void releaseExecutionMemory(long size, MemoryConsumer consumer) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add an assert to make sure size >= 0?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

assert(size >= 0);
if (size == 0) {
return;
}
synchronized (this) {
if (consumers.containsKey(consumer)) {
long old = consumers.get(consumer);
if (old > size) {
consumers.put(consumer, old - size);
} else {
if (old < size) {
String msg = "Release " + size + " bytes memory (more than acquired " + old + ") for "
+ consumer;
logger.warn(msg);
if (Utils.isTesting()) {
Platform.throwException(new SparkException(msg));
}
}
consumers.remove(consumer);
}
} else {
String msg = "Release " + size + " bytes memory for non-existent " + consumer;
logger.warn(msg);
if (Utils.isTesting()) {
Platform.throwException(new SparkException(msg));
}
}
}

logger.debug("Task {} release {} from {}", taskAttemptId, Utils.bytesToString(size), consumer);
memoryManager.releaseExecutionMemory(size, taskAttemptId);
}

public void transferOwnership(long size, MemoryConsumer from, MemoryConsumer to) {
assert(size >= 0);
synchronized (this) {
if (consumers.containsKey(from)) {
long old = consumers.get(from);
if (old > size) {
consumers.put(from, old - size);
} else {
consumers.remove(from);
}
if (consumers.containsKey(to)) {
old = consumers.get(to);
} else {
old = 0L;
}
consumers.put(to, old + size);
}
}
}

/**
* Dump the memory usage of all consumers.
*/
public void showMemoryUsage() {
logger.info("Memory used in task " + taskAttemptId);
synchronized (this) {
for (MemoryConsumer c: consumers.keySet()) {
logger.info("Acquired by " + c + ": " + Utils.bytesToString(consumers.get(c)));
for (MemoryConsumer c: consumers) {
if (c.getUsed() > 0) {
logger.info("Acquired by " + c + ": " + Utils.bytesToString(c.getUsed()));
}
}
}
}
Expand Down Expand Up @@ -399,8 +339,11 @@ public long getOffsetInPage(long pagePlusOffsetAddress) {
public long cleanUpAllAllocatedMemory() {
synchronized (this) {
Arrays.fill(pageTable, null);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea.

for (MemoryConsumer c: consumers.keySet()) {
logger.warn("leak " + Utils.bytesToString(consumers.get(c)) + " memory from " + c);
for (MemoryConsumer c: consumers) {
if (c != null && c.getUsed() > 0) {
// In case of failed task, it's normal to see leaked memory
logger.warn("leak " + Utils.bytesToString(c.getUsed()) + " memory from " + c);
}
}
consumers.clear();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,11 @@ public ShuffleExternalSorter(
int numPartitions,
SparkConf conf,
ShuffleWriteMetrics writeMetrics) {
super(memoryManager);
super(memoryManager, (int) Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES,
memoryManager.pageSizeBytes()));
this.taskMemoryManager = memoryManager;
this.blockManager = blockManager;
this.taskContext = taskContext;
this.peakMemoryUsedBytes = initialSize;
this.numPartitions = numPartitions;
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
Expand All @@ -116,6 +116,7 @@ public ShuffleExternalSorter(
this.writeMetrics = writeMetrics;
acquireMemory(initialSize * 8L);
this.inMemSorter = new ShuffleInMemorySorter(initialSize);
this.peakMemoryUsedBytes = getMemoryUsage();
}

/**
Expand Down Expand Up @@ -372,6 +373,7 @@ public void insertRecord(Object recordBase, long recordOffset, int length, int p
}

growPointerArrayIfNecessary();
// Need 4 bytes to store the record length.
final int required = length + 4;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd add back the now-missing comment that says "Need 4 bytes to store the record length."

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added.

acquireNewPageIfNecessary(required);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ private UnsafeExternalSorter(
this.inMemSorter = existingInMemorySorter;
// will acquire after free the map
}
this.peakMemoryUsedBytes = getMemoryUsage();

// Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
// the end of the task. This is necessary to avoid memory leaks in when the downstream operator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,28 +28,23 @@
public class TaskMemoryManagerSuite {

class TestMemoryConsumer extends MemoryConsumer {
volatile long used = 0L;

TestMemoryConsumer(TaskMemoryManager memoryManager) {
super(memoryManager);
}

@Override
public long spill(long size, MemoryConsumer trigger) throws IOException {
long used = getUsed();
releaseMemory(used);
long released = used;
used = 0;
return released;
return used;
}

void use(long size) {
acquireMemory(size);
used += size;
}

void free(long size) {
releaseMemory(size);
used -= size;
}
}

Expand Down Expand Up @@ -93,33 +88,33 @@ public void cooperativeSpilling() {
TestMemoryConsumer c1 = new TestMemoryConsumer(manager);
TestMemoryConsumer c2 = new TestMemoryConsumer(manager);
c1.use(100);
assert(c1.used == 100);
assert(c1.getUsed() == 100);
c2.use(100);
assert(c2.used == 100);
assert(c1.used == 0); // spilled
assert(c2.getUsed() == 100);
assert(c1.getUsed() == 0); // spilled
c1.use(100);
assert(c1.used == 100);
assert(c2.used == 0); // spilled
assert(c1.getUsed() == 100);
assert(c2.getUsed() == 0); // spilled

c1.use(50);
assert(c1.used == 50); // spilled
assert(c2.used == 0);
assert(c1.getUsed() == 50); // spilled
assert(c2.getUsed() == 0);
c2.use(50);
assert(c1.used == 50);
assert(c2.used == 50);
assert(c1.getUsed() == 50);
assert(c2.getUsed() == 50);

c1.use(100);
assert(c1.used == 100);
assert(c2.used == 0); // spilled
assert(c1.getUsed() == 100);
assert(c2.getUsed() == 0); // spilled

c1.free(20);
assert(c1.used == 80);
assert(c1.getUsed() == 80);
c2.use(10);
assert(c1.used == 80);
assert(c2.used == 10);
assert(c1.getUsed() == 80);
assert(c2.getUsed() == 10);
c2.use(100);
assert(c2.used == 100);
assert(c1.used == 0); // spilled
assert(c2.getUsed() == 100);
assert(c1.getUsed() == 0); // spilled

c1.free(0);
c2.free(100);
Expand Down