SPARK-25004: Add spark.executor.pyspark.memory limit.
rdblue committed Aug 23, 2018
commit a5004badcea9527873e976a208a83abef2a73b66
14 changes: 13 additions & 1 deletion core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -37,6 +37,7 @@ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.input.PortableDataStream
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.PYSPARK_EXECUTOR_MEMORY
import org.apache.spark.rdd.RDD
import org.apache.spark.security.SocketAuthHelper
import org.apache.spark.util._
@@ -52,6 +53,17 @@ private[spark] class PythonRDD(
val bufferSize = conf.getInt("spark.buffer.size", 65536)
val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true)

val memoryMb = {
Contributor:

It's been a while since I spent a lot of time thinking about how we launch our Python worker processes. Maybe it would make sense to add a comment here explaining the logic a bit more? Based on the documentation in PythonWorkerFactory, it appears the fork/not-fork decision is based not on whether reuseWorker is set but on whether we're on Windows. Is that the logic this block was attempting to handle?

Contributor Author:

I thought the comments below were clear: if a single worker is reused, it gets the entire allocation. If each core starts its own worker, each one gets an equal share.

If reuseWorker is actually ignored, then this needs to be updated.

Contributor:

I think there might be a misunderstanding about what reuseWorker means. The workers will be reused, but the decision about whether we fork in Python is based on whether we are on Windows. How about we both go and read that code path and see if we reach the same understanding? I could be off too.

val allocation = conf.get(PYSPARK_EXECUTOR_MEMORY)
if (reuseWorker) {
// the shared python worker gets the entire allocation
allocation
} else {
// each python worker gets an equal part of the allocation
allocation.map(_ / conf.getInt("spark.executor.cores", 1))
}
}

override def getPartitions: Array[Partition] = firstParent.partitions

override val partitioner: Option[Partitioner] = {
@@ -61,7 +73,7 @@ private[spark] class PythonRDD(
val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)

override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val runner = PythonRunner(func, bufferSize, reuseWorker)
val runner = PythonRunner(func, bufferSize, reuseWorker, memoryMb)
runner.compute(firstParent.iterator(split, context), split.index, context)
}

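For readers skimming the diff, the split above is the whole story of the new memoryMb value: a reused worker gets the full spark.executor.pyspark.memory allocation, per-core workers each get an equal share. The sketch below is purely illustrative (Python rather than the Scala in this file, and the standalone function name is made up); it only mirrors the arithmetic of the block above.

```python
def python_worker_memory_mb(allocation_mb, reuse_worker, executor_cores):
    """Illustrative sketch of the memoryMb logic above (the real change is Scala)."""
    if allocation_mb is None:        # spark.executor.pyspark.memory not set
        return None
    if reuse_worker:
        return allocation_mb         # the shared python worker gets the entire allocation
    return allocation_mb // max(executor_cores, 1)  # equal share per per-core worker

# e.g. a 4096 MB allocation on a 4-core executor
assert python_worker_memory_mb(4096, reuse_worker=False, executor_cores=4) == 1024
assert python_worker_memory_mb(4096, reuse_worker=True, executor_cores=4) == 4096
```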
19 changes: 14 additions & 5 deletions core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -65,7 +65,8 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
bufferSize: Int,
reuseWorker: Boolean,
evalType: Int,
argOffsets: Array[Array[Int]])
argOffsets: Array[Array[Int]],
pythonMemoryMb: Option[Long])
extends Logging {

require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")
@@ -95,6 +96,9 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
if (reuseWorker) {
envVars.put("SPARK_REUSE_WORKER", "1")
}
if (pythonMemoryMb.isDefined) {
envVars.put("PYSPARK_EXECUTOR_MEMORY_MB", pythonMemoryMb.get.toString)
}
val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
// Whether is the worker released into idle pool
val released = new AtomicBoolean(false)
@@ -485,8 +489,12 @@ private[spark] abstract class BasePythonRunner[IN, OUT](

private[spark] object PythonRunner {

def apply(func: PythonFunction, bufferSize: Int, reuseWorker: Boolean): PythonRunner = {
new PythonRunner(Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuseWorker)
def apply(
func: PythonFunction,
bufferSize: Int,
reuseWorker: Boolean,
pyMemoryMb: Option[Long]): PythonRunner = {
new PythonRunner(Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuseWorker, pyMemoryMb)
}
}

@@ -496,9 +504,10 @@ private[spark] object PythonRunner {
private[spark] class PythonRunner(
funcs: Seq[ChainedPythonFunctions],
bufferSize: Int,
reuseWorker: Boolean)
reuseWorker: Boolean,
pyMemoryMb: Option[Long])
extends BasePythonRunner[Array[Byte], Array[Byte]](
funcs, bufferSize, reuseWorker, PythonEvalType.NON_UDF, Array(Array(0))) {
funcs, bufferSize, reuseWorker, PythonEvalType.NON_UDF, Array(Array(0)), pyMemoryMb) {

protected override def newWriterThread(
env: SparkEnv,
@@ -114,6 +114,10 @@ package object config {
.checkValue(_ >= 0, "The off-heap memory size must not be negative")
.createWithDefault(0)

private[spark] val PYSPARK_EXECUTOR_MEMORY = ConfigBuilder("spark.executor.pyspark.memory")
Contributor:

Argh, should have noticed this before. Should this be added to configuration.md?

Contributor Author:

Yes, it should. I'll fix it.

.bytesConf(ByteUnit.MiB)
.createOptional
Member:

tiny nit: indentation ..

Contributor Author:

Fixed.


private[spark] val IS_PYTHON_APP = ConfigBuilder("spark.yarn.isPython").internal()
.booleanConf.createWithDefault(false)

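For reference, the new key would be set like any other Spark conf. The snippet below is a hypothetical usage sketch (the application name and value are made up); per the YARN changes later in this diff, the same value is also added to the executor container request for Python applications.

```python
from pyspark.sql import SparkSession

# Hypothetical usage sketch: cap each executor's Python worker memory.
# spark.executor.pyspark.memory is the key introduced above (a byte-size conf in MiB).
spark = (SparkSession.builder
         .appName("pyspark-memory-example")              # made-up app name
         .config("spark.executor.pyspark.memory", "2g")  # made-up value
         .getOrCreate())
```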
22 changes: 22 additions & 0 deletions python/pyspark/worker.py
@@ -22,6 +22,7 @@
import os
import sys
import time
import resource
import socket
import traceback

@@ -263,6 +264,27 @@ def main(infile, outfile):
isBarrier = read_bool(infile)
boundPort = read_int(infile)
secret = UTF8Deserializer().loads(infile)

# set up memory limits
memory_limit_mb = int(os.environ.get('PYSPARK_EXECUTOR_MEMORY_MB', "-1"))
total_memory = resource.RLIMIT_AS
try:
(total_memory_limit, max_total_memory) = resource.getrlimit(total_memory)
msg = "Current mem: {0} of max {1}\n".format(total_memory_limit, max_total_memory)
sys.stderr.write()
Member:

Forgot to output msg here?

Contributor Author:

Fixed.


if memory_limit_mb > 0 and total_memory_limit < 0:
Contributor:

So the logic of this block appears to be: the user has requested a memory limit and Python does not currently have one set. If the user has requested a different limit than the one already in place, though, wouldn't it make sense to set it regardless of whether a current limit exists?

Also possible I've misunderstood the rlimit return values here.

That being said, even if that is the behaviour we want, should we use resource.RLIM_INFINITY to check whether it's unlimited?

Contributor Author:

I've updated to use resource.RLIM_INFINITY.

I think this should only set the resource limit if it isn't already set. It is unlikely that it's already set because this is during worker initialization, but the intent is not to cause harm if a higher-level system (e.g., a container provider) has already set the limit.

Contributor:

That makes sense. What about only setting the limit when ours is lower than the current one? (e.g., I could see a container system setting a limit based on an assumption that doesn't hold once Spark is in the mix, and if we come up with a lower limit we could apply it.)

Contributor Author:

Works for me. I'll update this.

# convert to bytes
total_memory_limit = memory_limit_mb * 1024 * 1024

msg = "Setting mem to {0} of max {1}\n".format(total_memory_limit, max_total_memory)
sys.stderr.write(msg)
resource.setrlimit(total_memory, (total_memory_limit, total_memory_limit))

except (resource.error, OSError) as e:
# not all systems support resource limits, so warn instead of failing
sys.stderr.write("WARN: Failed to set memory limit: {0}\n".format(e))
Member:

Also catch ValueError, for the case where the hard limit can't be set (if one is otherwise set).

Contributor Author:

Fixed.


# initialize global state
taskContext = None
if isBarrier:
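The threads above converge on three changes to this block: compare against resource.RLIM_INFINITY rather than a negative value, only apply the limit when ours is lower than the current one, and also catch ValueError. The following is a hedged sketch of how the block might look with that feedback applied; it is not the final committed code.

```python
import os
import resource
import sys

# Sketch only: the worker.py limit-setting block with the review feedback applied
# (RLIM_INFINITY check, only lower an existing limit, also catch ValueError).
memory_limit_mb = int(os.environ.get('PYSPARK_EXECUTOR_MEMORY_MB', "-1"))
if memory_limit_mb > 0:
    try:
        (soft_limit, hard_limit) = resource.getrlimit(resource.RLIMIT_AS)
        msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit)
        sys.stderr.write(msg)

        new_limit = memory_limit_mb * 1024 * 1024  # convert MiB to bytes

        # only set the limit if there is none, or if ours is lower than the current one
        if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit:
            msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit)
            sys.stderr.write(msg)
            resource.setrlimit(resource.RLIMIT_AS, (new_limit, new_limit))
    except (resource.error, OSError, ValueError) as e:
        # not all systems support resource limits, so warn instead of failing
        sys.stderr.write("WARN: Failed to set memory limit: {0}\n".format(e))
```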
@@ -91,6 +91,13 @@ private[spark] class Client(
private val executorMemoryOverhead = sparkConf.get(EXECUTOR_MEMORY_OVERHEAD).getOrElse(
math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toLong, MEMORY_OVERHEAD_MIN)).toInt

private val isPython = sparkConf.get(IS_PYTHON_APP)
Contributor:

This is interesting. My one concern here is probably a little esoteric: for mixed-language pipelines this might not behave as desired. I'd suggest maybe a JIRA and a note on the config param that it only applies to Python apps, not mixed ones.

Contributor Author:

Is there documentation on how to create mixed-language pipelines? Clearly, all you need is a PythonRDD in your plan, but I thought it was non-trivial to create those from a Scala job.

Contributor:

It's true, creating mixed-language pipelines is difficult and not documented. But I do it, and some others do as well. Some cloud providers (Databricks is the most notable example) provide mixed-language pipelines in their notebook solutions, I believe, so that also reaches a larger audience than the people who do it manually.

Contributor:

That's not really documented but as Holden says, it exists. Livy does that - but Livy actually goes ahead and sets the internal spark.yarn.isPython property, so it would actually take advantage of this code...

Not sure how others do it, but all the ways I thought on how to expose this as an option were pretty hacky, so I think it's ok to leave things like this for now.

Contributor:

Interesting, I'll add this to my example mixed pipeline repo so folks can see this hack.

Contributor Author:

@holdenk, can you point me to that repo? I'd love to have a look at how you do mixed pipelines.

private val pysparkWorkerMemory: Int = if (isPython) {
sparkConf.get(PYSPARK_EXECUTOR_MEMORY).map(_.toInt).getOrElse(0)
} else {
0
}

private val distCacheMgr = new ClientDistributedCacheManager()

private val principal = sparkConf.get(PRINCIPAL).orNull
@@ -333,7 +340,7 @@
val maxMem = newAppResponse.getMaximumResourceCapability().getMemory()
logInfo("Verifying our application has not requested more than the maximum " +
s"memory capability of the cluster ($maxMem MB per container)")
val executorMem = executorMemory + executorMemoryOverhead
val executorMem = executorMemory + executorMemoryOverhead + pysparkWorkerMemory
if (executorMem > maxMem) {
throw new IllegalArgumentException(s"Required executor memory ($executorMemory" +
s"+$executorMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster! " +
Member:

Should add pysparkWorkerMemory here too.

Contributor:

Maybe just switch it to use the total $executorMem instead?

Contributor Author:

I like having it broken out so users can see where their allocation is going. Otherwise, users who only know about spark.executor.memory might not understand why their allocation is 1 GB higher when running PySpark. I've updated this to include the worker memory.

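To make the sizing check above concrete: for a Python application the requested container is now heap plus overhead plus the Python worker allocation, and that sum is what must fit under the cluster's per-container maximum. A small worked example follows; all numbers are made up and the check itself lives in the Scala Client above.

```python
# Made-up numbers, illustrating the Client.scala check above.
executor_memory_mb = 4096           # spark.executor.memory
executor_memory_overhead_mb = 409   # per the max(factor * heap, minimum) formula above; value made up
pyspark_worker_memory_mb = 2048     # spark.executor.pyspark.memory (Python apps only)
max_container_mb = 8192             # hypothetical YARN max allocation per container

executor_mem_mb = executor_memory_mb + executor_memory_overhead_mb + pyspark_worker_memory_mb
if executor_mem_mb > max_container_mb:
    raise ValueError(
        "Required executor memory ({0}+{1}+{2} MB) is above the max threshold "
        "({3} MB) of this cluster!".format(
            executor_memory_mb, executor_memory_overhead_mb,
            pyspark_worker_memory_mb, max_container_mb))
print("Requesting {0} MB containers".format(executor_mem_mb))  # 6553 MB here
```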
@@ -133,10 +133,17 @@ private[yarn] class YarnAllocator(
// Additional memory overhead.
protected val memoryOverhead: Int = sparkConf.get(EXECUTOR_MEMORY_OVERHEAD).getOrElse(
math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toInt, MEMORY_OVERHEAD_MIN)).toInt
protected val pysparkWorkerMemory: Int = if (sparkConf.get(IS_PYTHON_APP)) {
sparkConf.get(PYSPARK_EXECUTOR_MEMORY).map(_.toInt).getOrElse(0)
Member:

nit: default to -1 to be consistent?

Member:

or just use 0 in worker.py too

Contributor Author:

-1 in worker.py signals that it isn't set. Here, we use an Option instead. 0 is the correct size of the allocation to add to YARN resource requests.

Member:

got it

} else {
0
}
// Number of cores per executor.
protected val executorCores = sparkConf.get(EXECUTOR_CORES)
// Resource capability requested for each executors
private[yarn] val resource = Resource.newInstance(executorMemory + memoryOverhead, executorCores)
private[yarn] val resource = Resource.newInstance(
executorMemory + memoryOverhead + pysparkWorkerMemory,
executorCores)

private val launcherPool = ThreadUtils.newDaemonCachedThreadPool(
"ContainerLauncher", sparkConf.get(CONTAINER_LAUNCH_MAX_THREADS))
@@ -23,6 +23,7 @@ import scala.collection.mutable.ArrayBuffer

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.internal.config.PYSPARK_EXECUTOR_MEMORY
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -81,6 +82,17 @@ case class AggregateInPandasExec(

val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
val memoryMb = {
Contributor:

This is minor, but this code block is repeated; would it make sense to factor it out?

Contributor Author:

The other configuration options are already duplicated, so I was trying to make as few changes as possible.

Since there are several duplicated options, I think it makes more sense to pass the SparkConf through to PythonRunner so it can extract its own configuration.

@holdenk, would you like this refactor done in this PR, or should I do it in a follow-up?

Contributor Author:

I went ahead with the refactor.

val allocation = inputRDD.conf.get(PYSPARK_EXECUTOR_MEMORY)
if (reuseWorker) {
// the shared python worker gets the entire allocation
allocation
} else {
// each python worker gets an equal part of the allocation
allocation.map(_ / inputRDD.conf.getInt("spark.executor.cores", 1))
}
}

val sessionLocalTimeZone = conf.sessionLocalTimeZone
val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)

@@ -139,6 +151,7 @@ case class AggregateInPandasExec(
pyFuncs,
bufferSize,
reuseWorker,
memoryMb,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
argOffsets,
aggInputSchema,
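The refactor agreed to above amounts to handing the conf to the runner and letting it derive its own settings, instead of every operator repeating the bufferSize/reuseWorker/memoryMb extraction. The sketch below illustrates that shape only; it is Python with made-up names and a plain dict standing in for SparkConf, not the actual Spark change.

```python
class PythonRunnerSettings(object):
    """Illustrative only: derive runner settings from a conf handed in by the caller."""

    def __init__(self, conf):
        # the extraction logic lives in one place instead of in every operator
        self.buffer_size = int(conf.get("spark.buffer.size", "65536"))
        self.reuse_worker = conf.get("spark.python.worker.reuse", "true") == "true"
        cores = int(conf.get("spark.executor.cores", "1"))
        allocation = conf.get("spark.executor.pyspark.memory")  # MiB here, for simplicity
        if allocation is None:
            self.memory_mb = None
        elif self.reuse_worker:
            self.memory_mb = int(allocation)
        else:
            self.memory_mb = int(allocation) // cores

# made-up usage: a dict stands in for SparkConf
conf = {"spark.python.worker.reuse": "false",
        "spark.executor.pyspark.memory": "4096",
        "spark.executor.cores": "4"}
assert PythonRunnerSettings(conf).memory_mb == 1024
```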
@@ -70,6 +70,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
funcs: Seq[ChainedPythonFunctions],
bufferSize: Int,
reuseWorker: Boolean,
pyMemoryMb: Option[Long],
argOffsets: Array[Array[Int]],
iter: Iterator[InternalRow],
schema: StructType,
@@ -84,6 +85,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
funcs,
bufferSize,
reuseWorker,
pyMemoryMb,
PythonEvalType.SQL_SCALAR_PANDAS_UDF,
argOffsets,
schema,
@@ -41,13 +41,14 @@ class ArrowPythonRunner(
funcs: Seq[ChainedPythonFunctions],
bufferSize: Int,
reuseWorker: Boolean,
pyMemoryMb: Option[Long],
evalType: Int,
argOffsets: Array[Array[Int]],
schema: StructType,
timeZoneId: String,
conf: Map[String, String])
extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](
funcs, bufferSize, reuseWorker, evalType, argOffsets) {
funcs, bufferSize, reuseWorker, evalType, argOffsets, pyMemoryMb) {

protected override def newWriterThread(
env: SparkEnv,
@@ -23,6 +23,7 @@ import net.razorvine.pickle.{Pickler, Unpickler}

import org.apache.spark.TaskContext
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.internal.config.PYSPARK_EXECUTOR_MEMORY
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
@@ -38,6 +39,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
funcs: Seq[ChainedPythonFunctions],
bufferSize: Int,
reuseWorker: Boolean,
pyMemoryMb: Option[Long],
argOffsets: Array[Array[Int]],
iter: Iterator[InternalRow],
schema: StructType,
@@ -69,7 +71,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi

// Output iterator for results from Python.
val outputIterator = new PythonUDFRunner(
funcs, bufferSize, reuseWorker, PythonEvalType.SQL_BATCHED_UDF, argOffsets)
funcs, bufferSize, reuseWorker, PythonEvalType.SQL_BATCHED_UDF, argOffsets, pyMemoryMb)
.compute(inputIterator, context.partitionId(), context)

val unpickle = new Unpickler
@@ -23,6 +23,7 @@ import scala.collection.mutable.ArrayBuffer

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.python.ChainedPythonFunctions
import org.apache.spark.internal.config.PYSPARK_EXECUTOR_MEMORY
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -80,6 +81,7 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil
funcs: Seq[ChainedPythonFunctions],
bufferSize: Int,
reuseWorker: Boolean,
pyMemoryMb: Option[Long],
argOffsets: Array[Array[Int]],
iter: Iterator[InternalRow],
schema: StructType,
@@ -89,6 +91,16 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil
val inputRDD = child.execute().map(_.copy())
val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
val memoryMb = {
Contributor:

Same repeated code block as mentioned.

val allocation = inputRDD.conf.get(PYSPARK_EXECUTOR_MEMORY)
if (reuseWorker) {
// the shared python worker gets the entire allocation
allocation
} else {
// each python worker gets an equal part of the allocation
allocation.map(_ / inputRDD.conf.getInt("spark.executor.cores", 1))
}
}

inputRDD.mapPartitions { iter =>
val context = TaskContext.get()
@@ -129,7 +141,7 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil
}

val outputRowIterator = evaluate(
pyFuncs, bufferSize, reuseWorker, argOffsets, projectedRowIter, schema, context)
pyFuncs, bufferSize, reuseWorker, memoryMb, argOffsets, projectedRowIter, schema, context)

val joined = new JoinedRow
val resultProj = UnsafeProjection.create(output, output)
@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer

import org.apache.spark.TaskContext
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.internal.config.PYSPARK_EXECUTOR_MEMORY
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -76,6 +77,16 @@ case class FlatMapGroupsInPandasExec(

val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
val memoryMb = {
val allocation = inputRDD.conf.get(PYSPARK_EXECUTOR_MEMORY)
if (reuseWorker) {
// the shared python worker gets the entire allocation
allocation
} else {
// each python worker gets an equal part of the allocation
allocation.map(_ / inputRDD.conf.getInt("spark.executor.cores", 1))
}
}
val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
val sessionLocalTimeZone = conf.sessionLocalTimeZone
val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
@@ -143,6 +154,7 @@ case class FlatMapGroupsInPandasExec(
chainedFunc,
bufferSize,
reuseWorker,
memoryMb,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
argOffsets,
dedupSchema,
@@ -24,6 +24,7 @@ import java.util.concurrent.locks.ReentrantLock
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.python._
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.PYSPARK_EXECUTOR_MEMORY
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.sql.ForeachWriter
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
@@ -48,7 +49,17 @@ class PythonForeachWriter(func: PythonFunction, schema: StructType)
val conf = SparkEnv.get.conf
val bufferSize = conf.getInt("spark.buffer.size", 65536)
val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true)
PythonRunner(func, bufferSize, reuseWorker)
val memoryMb = {
val allocation = conf.get(PYSPARK_EXECUTOR_MEMORY)
if (reuseWorker) {
// the shared python worker gets the entire allocation
allocation
} else {
// each python worker gets an equal part of the allocation
allocation.map(_ / conf.getInt("spark.executor.cores", 1))
}
}
PythonRunner(func, bufferSize, reuseWorker, memoryMb)
}

private lazy val outputIterator =
@@ -32,9 +32,10 @@ class PythonUDFRunner(
bufferSize: Int,
reuseWorker: Boolean,
evalType: Int,
argOffsets: Array[Array[Int]])
argOffsets: Array[Array[Int]],
pyMemoryMb: Option[Long])
extends BasePythonRunner[Array[Byte], Array[Byte]](
funcs, bufferSize, reuseWorker, evalType, argOffsets) {
funcs, bufferSize, reuseWorker, evalType, argOffsets, pyMemoryMb) {

protected override def newWriterThread(
env: SparkEnv,