core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -207,6 +207,7 @@ private[spark] class PythonRDD(

override def run(): Unit = Utils.logUncaughtExceptions {
try {
TaskContext.setTaskContext(context)
val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
val dataOut = new DataOutputStream(stream)
// Partition index
@@ -263,11 +264,6 @@
if (!worker.isClosed) {
Utils.tryLog(worker.shutdownOutput())
}
Contributor Author: /cc @davies for this PySpark change.

} finally {
// Release memory used by this thread for shuffles
env.shuffleMemoryManager.releaseMemoryForThisThread()
// Release memory used by this thread for unrolling blocks
env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
}
}
}
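A spawned writer thread does not inherit the task thread's ThreadLocal TaskContext, so it has to be installed explicitly before the thread touches anything that calls TaskContext.get() (for example the shuffle memory manager, which now keys allocations by task attempt id). The same pattern recurs in RRDD and PipedRDD below. A minimal sketch of the idea, with illustrative names and assuming code inside the org.apache.spark package (setTaskContext is Spark-internal):

```scala
package org.apache.spark

// Sketch only: WriterThread and writePartition are illustrative, not Spark API.
private[spark] class WriterThread(
    context: TaskContext,
    writePartition: () => Unit) extends Thread("writer for worker") {

  override def run(): Unit = {
    // Plain ThreadLocals are not inherited by child threads, so install the
    // TaskContext captured on the task's main thread before any code here
    // calls TaskContext.get().
    TaskContext.setTaskContext(context)
    writePartition()
  }
}

// On the task's main thread:
//   val context = TaskContext.get()
//   new WriterThread(context, () => /* write partition data */ ()).start()
```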
2 changes: 2 additions & 0 deletions core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -112,13 +112,15 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
partition: Int): Unit = {

val env = SparkEnv.get
val taskContext = TaskContext.get()
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
val stream = new BufferedOutputStream(output, bufferSize)

new Thread("writer for R") {
override def run(): Unit = {
try {
SparkEnv.set(env)
TaskContext.setTaskContext(taskContext)
val dataOut = new DataOutputStream(stream)
dataOut.writeInt(partition)

4 changes: 0 additions & 4 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -313,10 +313,6 @@ private[spark] class Executor(
}

} finally {
// Release memory used by this thread for shuffles
env.shuffleMemoryManager.releaseMemoryForThisThread()
// Release memory used by this thread for unrolling blocks
env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
runningTasks.remove(taskId)
}
}
1 change: 1 addition & 0 deletions core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -133,6 +133,7 @@ private[spark] class PipedRDD[T: ClassTag](
// Start a thread to feed the process input from our parent's iterator
new Thread("stdin writer for " + command) {
override def run() {
TaskContext.setTaskContext(context)
val out = new PrintWriter(proc.getOutputStream)

// scalastyle:off println
15 changes: 13 additions & 2 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -23,7 +23,7 @@ import java.nio.ByteBuffer
import scala.collection.mutable.HashMap

import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.{TaskContextImpl, TaskContext}
import org.apache.spark.{SparkEnv, TaskContextImpl, TaskContext}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.unsafe.memory.TaskMemoryManager
@@ -86,7 +86,18 @@ private[spark] abstract class Task[T](
(runTask(context), context.collectAccumulators())
} finally {
context.markTaskCompleted()
TaskContext.unset()
try {
Utils.tryLogNonFatalError {
// Release memory used by this thread for shuffles
SparkEnv.get.shuffleMemoryManager.releaseMemoryForThisTask()
}
Utils.tryLogNonFatalError {
// Release memory used by this thread for unrolling blocks
SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask()
}
} finally {
TaskContext.unset()
}
}
}
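Each release call above is wrapped in Utils.tryLogNonFatalError, and both sit inside a try whose finally unsets the TaskContext, so a failure while freeing shuffle or unroll memory can neither skip the other release nor leave a stale context on the executor thread. A condensed sketch of that ordering (illustrative function parameters, not the Task.run code itself; assumes the org.apache.spark package for access to TaskContext.unset):

```scala
package org.apache.spark

import scala.util.control.NonFatal

object TaskCleanupSketch {
  // releaseShuffleMemory / releaseUnrollMemory stand in for the SparkEnv calls above.
  def cleanupAfterTask(
      releaseShuffleMemory: () => Unit,
      releaseUnrollMemory: () => Unit): Unit = {
    try {
      // Each step is attempted independently; a non-fatal error is logged
      // (printed here) instead of aborting the remaining cleanup.
      try releaseShuffleMemory() catch { case NonFatal(e) => println(s"ignored: $e") }
      try releaseUnrollMemory() catch { case NonFatal(e) => println(s"ignored: $e") }
    } finally {
      // Always runs last, so the thread never keeps a stale TaskContext.
      TaskContext.unset()
    }
  }
}
```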

core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
@@ -19,95 +19,101 @@ package org.apache.spark.shuffle

import scala.collection.mutable

import org.apache.spark.{Logging, SparkException, SparkConf}
import org.apache.spark.{Logging, SparkException, SparkConf, TaskContext}

/**
* Allocates a pool of memory to task threads for use in shuffle operations. Each disk-spilling
* Allocates a pool of memory to tasks for use in shuffle operations. Each disk-spilling
* collection (ExternalAppendOnlyMap or ExternalSorter) used by these tasks can acquire memory
* from this pool and release it as it spills data out. When a task ends, all its memory will be
* released by the Executor.
*
* This class tries to ensure that each thread gets a reasonable share of memory, instead of some
* thread ramping up to a large amount first and then causing others to spill to disk repeatedly.
* If there are N threads, it ensures that each thread can acquire at least 1 / 2N of the memory
* This class tries to ensure that each task gets a reasonable share of memory, instead of some
* task ramping up to a large amount first and then causing others to spill to disk repeatedly.
* If there are N tasks, it ensures that each task can acquire at least 1 / 2N of the memory
* before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the
* set of active threads and redo the calculations of 1 / 2N and 1 / N in waiting threads whenever
* set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever
* this set changes. This is all done by synchronizing access on "this" to mutate state and using
* wait() and notifyAll() to signal changes.
*/
private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
private val threadMemory = new mutable.HashMap[Long, Long]() // threadId -> memory bytes
private val taskMemory = new mutable.HashMap[Long, Long]() // taskAttemptId -> memory bytes

def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf))

private def currentTaskAttemptId(): Long = {
// In case this is called on the driver, return an invalid task attempt id.
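// -1 cannot clash with a real attempt id (attempt ids are non-negative), and TaskContext.get() is null on the driver.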
Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L)
Contributor: add a line explaining why the default value is needed

}

/**
* Try to acquire up to numBytes memory for the current thread, and return the number of bytes
* Try to acquire up to numBytes memory for the current task, and return the number of bytes
* obtained, or 0 if none can be allocated. This call may block until there is enough free memory
* in some situations, to make sure each thread has a chance to ramp up to at least 1 / 2N of the
* total memory pool (where N is the # of active threads) before it is forced to spill. This can
* happen if the number of threads increases but an older thread had a lot of memory already.
* in some situations, to make sure each task has a chance to ramp up to at least 1 / 2N of the
* total memory pool (where N is the # of active tasks) before it is forced to spill. This can
* happen if the number of tasks increases but an older task had a lot of memory already.
*/
def tryToAcquire(numBytes: Long): Long = synchronized {
val threadId = Thread.currentThread().getId
val taskAttemptId = currentTaskAttemptId()
assert(numBytes > 0, "invalid number of bytes requested: " + numBytes)

// Add this thread to the threadMemory map just so we can keep an accurate count of the number
// of active threads, to let other threads ramp down their memory in calls to tryToAcquire
if (!threadMemory.contains(threadId)) {
threadMemory(threadId) = 0L
notifyAll() // Will later cause waiting threads to wake up and check numThreads again
// Add this task to the taskMemory map just so we can keep an accurate count of the number
// of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire
if (!taskMemory.contains(taskAttemptId)) {
taskMemory(taskAttemptId) = 0L
notifyAll() // Will later cause waiting tasks to wake up and check numTasks again
}

// Keep looping until we're either sure that we don't want to grant this request (because this
// thread would have more than 1 / numActiveThreads of the memory) or we have enough free
// memory to give it (we always let each thread get at least 1 / (2 * numActiveThreads)).
// task would have more than 1 / numActiveTasks of the memory) or we have enough free
// memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)).
while (true) {
val numActiveThreads = threadMemory.keys.size
val curMem = threadMemory(threadId)
val freeMemory = maxMemory - threadMemory.values.sum
val numActiveTasks = taskMemory.keys.size
val curMem = taskMemory(taskAttemptId)
val freeMemory = maxMemory - taskMemory.values.sum

// How much we can grant this thread; don't let it grow to more than 1 / numActiveThreads;
// How much we can grant this task; don't let it grow to more than 1 / numActiveTasks;
// don't let it be negative
val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveThreads) - curMem))
val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveTasks) - curMem))

if (curMem < maxMemory / (2 * numActiveThreads)) {
// We want to let each thread get at least 1 / (2 * numActiveThreads) before blocking;
// if we can't give it this much now, wait for other threads to free up memory
// (this happens if older threads allocated lots of memory before N grew)
if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveThreads) - curMem)) {
if (curMem < maxMemory / (2 * numActiveTasks)) {
// We want to let each task get at least 1 / (2 * numActiveTasks) before blocking;
// if we can't give it this much now, wait for other tasks to free up memory
// (this happens if older tasks allocated lots of memory before N grew)
if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveTasks) - curMem)) {
val toGrant = math.min(maxToGrant, freeMemory)
threadMemory(threadId) += toGrant
taskMemory(taskAttemptId) += toGrant
return toGrant
} else {
logInfo(s"Thread $threadId waiting for at least 1/2N of shuffle memory pool to be free")
logInfo(
s"Task $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free")
wait()
}
} else {
// Only give it as much memory as is free, which might be none if it reached 1 / numThreads
val toGrant = math.min(maxToGrant, freeMemory)
threadMemory(threadId) += toGrant
taskMemory(taskAttemptId) += toGrant
return toGrant
}
}
0L // Never reached
}

/** Release numBytes bytes for the current thread. */
/** Release numBytes bytes for the current task. */
def release(numBytes: Long): Unit = synchronized {
val threadId = Thread.currentThread().getId
val curMem = threadMemory.getOrElse(threadId, 0L)
val taskAttemptId = currentTaskAttemptId()
val curMem = taskMemory.getOrElse(taskAttemptId, 0L)
if (curMem < numBytes) {
throw new SparkException(
s"Internal error: release called on ${numBytes} bytes but thread only has ${curMem}")
s"Internal error: release called on ${numBytes} bytes but task only has ${curMem}")
}
threadMemory(threadId) -= numBytes
taskMemory(taskAttemptId) -= numBytes
notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed
}

/** Release all memory for the current thread and mark it as inactive (e.g. when a task ends). */
def releaseMemoryForThisThread(): Unit = synchronized {
val threadId = Thread.currentThread().getId
threadMemory.remove(threadId)
/** Release all memory for the current task and mark it as inactive (e.g. when a task ends). */
def releaseMemoryForThisTask(): Unit = synchronized {
val taskAttemptId = currentTaskAttemptId()
taskMemory.remove(taskAttemptId)
notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed
}
}
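To make the 1/2N–1/N policy in tryToAcquire concrete, here is a standalone restatement of the grant bound with one worked set of numbers (a sketch with illustrative names, not the Spark class itself):

```scala
object ShuffleShareSketch extends App {
  // Upper bound on a single grant, as computed in tryToAcquire above:
  // never let one task exceed 1 / numActiveTasks of the pool.
  def maxGrant(numBytes: Long, maxMemory: Long, numActiveTasks: Int, curMem: Long): Long =
    math.min(numBytes, math.max(0L, maxMemory / numActiveTasks - curMem))

  // Pool of 1000 bytes, 2 active tasks, this task already holds 300 bytes:
  // its ceiling is 1/N = 500, so a request for 400 is capped at 200.
  println(maxGrant(400L, 1000L, 2, 300L)) // 200

  // The 1/(2N) floor: with 2 tasks, a task holding only 100 bytes is entitled
  // to reach 250 before it can be forced to spill. If free memory is below
  // min(maxToGrant, 250 - 100) = 150, tryToAcquire blocks in wait() rather
  // than returning a smaller grant.
}
```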