improve performance, address comments
Davies Liu committed Mar 31, 2016
commit 8dc1adfb12a35280a01b4c8ab95b5aed346d8f0f
45 changes: 30 additions & 15 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -77,30 +77,42 @@ private[spark] case class PythonFunction(
broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: Accumulator[JList[Array[Byte]]])

/**
* A wrapper for chained Python functions (from bottom to top).
* @param funcs the chain of Python functions, ordered from bottom to top
*/
private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])

object PythonRunner {
private[spark] object PythonRunner {
def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
new PythonRunner(Seq(Seq(func)), bufferSize, reuse_worker, false, Seq(1))
new PythonRunner(
Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuse_worker, false, Seq(Seq(0)))
}
}

/**
* A helper class to run Python mapPartition/UDFs in Spark.
*
* funcs is a list of independent Python functions, each one of them is a list of chained Python
* functions (from bottom to top).
*/
private[spark] class PythonRunner(
funcs: Seq[Seq[PythonFunction]],
funcs: Seq[ChainedPythonFunctions],
bufferSize: Int,
reuse_worker: Boolean,
isUDF: Boolean,
Contributor: Similarly, do you mind adding scaladoc for these two new parameters?

numArgs: Seq[Int])
argOffsets: Seq[Seq[Int]])
extends Logging {

require(funcs.length == argOffsets.length, "numArgs should have the same length as funcs")
Contributor: numArgs -> argOffsets


// All the Python functions should have the same exec, version and envvars.
private val envVars = funcs.head.head.envVars
private val pythonExec = funcs.head.head.pythonExec
private val pythonVer = funcs.head.head.pythonVer
private val envVars = funcs.head.funcs.head.envVars
private val pythonExec = funcs.head.funcs.head.pythonExec
private val pythonVer = funcs.head.funcs.head.pythonVer

private val accumulator = funcs.head.head.accumulator // TODO: support accumulator in multiple UDF
// TODO: support accumulator in multiple UDF
private val accumulator = funcs.head.funcs.head.accumulator

def compute(
inputIterator: Iterator[_],
@@ -240,8 +252,8 @@ private[spark] class PythonRunner(

@volatile private var _exception: Exception = null

private val pythonIncludes = funcs.flatMap(_.flatMap(_.pythonIncludes.asScala)).toSet
private val broadcastVars = funcs.flatMap(_.flatMap(_.broadcastVars.asScala))
private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))

setDaemon(true)

@@ -295,17 +307,20 @@ private[spark] class PythonRunner(
if (isUDF) {
dataOut.writeInt(1)
dataOut.writeInt(funcs.length)
funcs.zip(numArgs).foreach { case (fs, numArg) =>
dataOut.writeInt(numArg)
dataOut.writeInt(fs.length)
fs.foreach { f =>
funcs.zip(argOffsets).foreach { case (chained, offsets) =>
dataOut.writeInt(offsets.length)
offsets.foreach { offset =>
dataOut.writeInt(offset)
}
dataOut.writeInt(chained.funcs.length)
chained.funcs.foreach { f =>
dataOut.writeInt(f.command.length)
dataOut.write(f.command)
}
}
} else {
dataOut.writeInt(0)
val command = funcs.head.head.command
val command = funcs.head.funcs.head.command
dataOut.writeInt(command.length)
dataOut.write(command)
}
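As an aside, here is a minimal standalone sketch of what "chained from bottom to top" means for ChainedPythonFunctions: the bottom function runs first and each function above it consumes the previous result, mirroring the chain helper in worker.py shown below. The UDFs (udf_inner, udf_outer) are hypothetical and not part of this diff.

    # Hypothetical illustration of bottom-to-top chaining, e.g. udf_outer(udf_inner(x)):
    # the functions arrive ordered [udf_inner, udf_outer] and are folded left to right.
    def chain(f, g):
        # run f first, then feed its result to g
        return lambda *a: g(f(*a))

    udf_inner = lambda x: x + 1    # bottom of the chain
    udf_outer = lambda x: x * 10   # top of the chain

    row_func = None
    for f in [udf_inner, udf_outer]:
        row_func = f if row_func is None else chain(row_func, f)

    print(row_func(4))  # 50, i.e. (4 + 1) * 10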
36 changes: 19 additions & 17 deletions python/pyspark/worker.py
@@ -63,11 +63,13 @@ def chain(f, g):


def wrap_udf(f, return_type):
return lambda *a: return_type.toInternal(f(*a))
toInternal = return_type.toInternal
return lambda *a: toInternal(f(*a))


def read_single_udf(pickleSer, infile):
num_arg = read_int(infile)
arg_offsets = [read_int(infile) for i in range(num_arg)]
row_func = None
for i in range(read_int(infile)):
f, return_type = read_command(pickleSer, infile)
@@ -76,27 +78,27 @@ def read_single_udf(pickleSer, infile):
else:
row_func = chain(row_func, f)
# the last returnType will be the return type of UDF
return num_arg, wrap_udf(row_func, return_type)
return arg_offsets, wrap_udf(row_func, return_type)


def read_udfs(pickleSer, infile):
num_udfs = read_int(infile)
udfs = []
offset = 0
for i in range(num_udfs):
num_arg, udf = read_single_udf(pickleSer, infile)
udfs.append((offset, offset + num_arg, udf))
offset += num_arg

if num_udfs == 1:
udf = udfs[0][2]

# fast path for single UDF
def mapper(args):
return udf(*args)
_, udf = read_single_udf(pickleSer, infile)
mapper = lambda a: udf(*a)
else:
def mapper(args):
return tuple(udf(*args[start:end]) for start, end, udf in udfs)
udfs = {}
call_udf = []
for i in range(num_udfs):
arg_offsets, udf = read_single_udf(pickleSer, infile)
udfs['f%d' % i] = udf
args = ["a[%d]" % o for o in arg_offsets]
call_udf.append("f%d(%s)" % (i, ", ".join(args)))
# Create function like this:
# lambda a: (f0(a0), f1(a1, a2), f2(a3))
mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
Contributor: Clever! This is a neat trick.

mapper = eval(mapper_str, udfs)

func = lambda _, it: map(mapper, it)
ser = AutoBatchedSerializer(PickleSerializer())
@@ -149,8 +151,8 @@ def main(infile, outfile):
_broadcastRegistry.pop(bid)

_accumulatorRegistry.clear()
is_udf = read_int(infile)
if is_udf:
is_sql_udf = read_int(infile)
if is_sql_udf:
func, profiler, deserializer, serializer = read_udfs(pickleSer, infile)
else:
func, profiler, deserializer, serializer = read_command(pickleSer, infile)
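For reference, a minimal standalone sketch of the eval-based code generation in read_udfs above: each UDF is bound to a name f0, f1, ... in a dict that serves as the generated lambda's globals, and the lambda indexes the flattened argument row by each UDF's offsets. The concrete UDFs and offsets here are made up for illustration.

    # Hypothetical UDFs standing in for the deserialized, wrapped row functions.
    udfs = {'f0': lambda x: x + 1,
            'f1': lambda x, y: x * y}
    call_udf = ["f0(a[0])", "f1(a[1], a[2])"]

    # Same construction as read_udfs: "lambda a: (f0(a[0]), f1(a[1], a[2]))"
    mapper_str = "lambda a: (%s)" % ", ".join(call_udf)
    mapper = eval(mapper_str, udfs)   # udfs acts as the lambda's globals

    print(mapper([3, 4, 5]))  # (4, 20)

Generating one flat lambda this way avoids an interpreted per-row loop over the UDF list, which is presumably part of the per-row overhead this commit aims to reduce.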
@@ -18,16 +18,17 @@
package org.apache.spark.sql.execution.python

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import net.razorvine.pickle.{Pickler, Unpickler}

import org.apache.spark.TaskContext
import org.apache.spark.api.python.{PythonFunction, PythonRunner}
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonFunction, PythonRunner}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.types.{DataType, StructField, StructType}


/**
@@ -45,15 +46,15 @@ case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], c

def children: Seq[SparkPlan] = child :: Nil

private def collectFunctions(udf: PythonUDF): (Seq[PythonFunction], Seq[Expression]) = {
private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
udf.children match {
case Seq(u: PythonUDF) =>
val (fs, children) = collectFunctions(u)
(fs ++ Seq(udf.func), children)
val (chained, children) = collectFunctions(u)
(ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
case children =>
// There should not be any other UDFs, or the children can't be evaluated directly.
assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
(Seq(udf.func), udf.children)
(ChainedPythonFunctions(Seq(udf.func)), udf.children)
}
}

@@ -69,30 +70,48 @@ case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], c
// combine input with output from Python.
val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()

val (pyFuncs, children) = udfs.map(collectFunctions).unzip
val numArgs = children.map(_.length)
val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip

val pickle = new Pickler
// Most of the inputs are primitives, do not use memo for better performance
val pickle = new Pickler(false)
// flatten all the arguments
val allChildren = children.flatMap(x => x)
val currentRow = newMutableProjection(allChildren, child.output)()
val fields = allChildren.map(_.dataType)
val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
val allInputs = new ArrayBuffer[Expression]
val dataTypes = new ArrayBuffer[DataType]
val argOffsets = inputs.map { input =>
input.map { e =>
if (allInputs.exists(_.semanticEquals(e))) {
Contributor: In the worst-case this loop is N^2, but N is probably pretty small so it probably doesn't matter compared to other perf. issues impacting Python UDFs.

Contributor Author: Agreed.

allInputs.indexWhere(_.semanticEquals(e))
} else {
allInputs += e
dataTypes += e.dataType
allInputs.length - 1
}
}
}
val projection = newMutableProjection(allInputs, child.output)()

// Input iterator to Python: input rows are grouped so we send them in batches to Python.
// For each row, add it to the queue.
val inputIterator = iter.grouped(100).map { inputRows =>
val toBePickled = inputRows.map { row =>
queue.add(row)
EvaluatePython.toJava(currentRow(row), schema)
val inputIterator = iter.grouped(1024).map { inputRows =>
val toBePickled = inputRows.map { inputRow =>
queue.add(inputRow)
val row = projection(inputRow)
val fields = new Array[Any](row.numFields)
var i = 0
while (i < row.numFields) {
val dt = dataTypes(i)
fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
i += 1
}
fields
}.toArray
pickle.dumps(toBePickled)
}

val context = TaskContext.get()

// Output iterator for results from Python.
val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, numArgs)
val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, argOffsets)
.compute(inputIterator, context.partitionId(), context)

val unpickle = new Unpickler
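To make the argOffsets bookkeeping above concrete, here is a minimal Python sketch of how duplicate inputs across UDFs are collapsed into a single projection and referenced by offset. Strings stand in for Catalyst expressions, and plain list membership stands in for semanticEquals; the column names are hypothetical.

    def compute_arg_offsets(inputs_per_udf):
        # all_inputs collects distinct input expressions in first-seen order;
        # each UDF's arguments become offsets into that flattened list.
        all_inputs = []
        arg_offsets = []
        for inputs in inputs_per_udf:
            offsets = []
            for e in inputs:
                if e in all_inputs:               # stand-in for semanticEquals
                    offsets.append(all_inputs.index(e))
                else:
                    all_inputs.append(e)
                    offsets.append(len(all_inputs) - 1)
            arg_offsets.append(offsets)
        return all_inputs, arg_offsets

    # Two UDFs sharing one column: f0(col_a) and f1(col_a, col_b)
    print(compute_arg_offsets([["col_a"], ["col_a", "col_b"]]))
    # (['col_a', 'col_b'], [[0], [0, 1]])

As noted in the review, the membership check makes this quadratic in the number of distinct inputs in the worst case, which is acceptable because that number is typically small.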