multiple Python UDFs in single batch
Davies Liu committed Mar 30, 2016
commit f6b737337dc087251899c33e71ef5cdf89f0c5a3
49 changes: 34 additions & 15 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -59,7 +59,7 @@ private[spark] class PythonRDD(
val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)

override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val runner = new PythonRunner(Seq(func), bufferSize, reuse_worker, false)
val runner = PythonRunner(func, bufferSize, reuse_worker)
runner.compute(firstParent.iterator(split, context), split.index, context)
}
}
@@ -77,22 +77,30 @@ private[spark] case class PythonFunction(
broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: Accumulator[JList[Array[Byte]]])


object PythonRunner {
Contributor: This should be private[spark].

def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
new PythonRunner(Seq(Seq(func)), bufferSize, reuse_worker, false, Seq(1))
}
}

/**
* A helper class to run Python UDFs in Spark.
* A helper class to run Python mapPartition/UDFs in Spark.
*/
private[spark] class PythonRunner(
funcs: Seq[PythonFunction],
funcs: Seq[Seq[PythonFunction]],
Contributor: This type is a little strange, so do you mind adding a scaladoc comment to explain what the two levels of nesting correspond to?
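A minimal illustration of the nesting (not part of the patch; names are hypothetical): the outer sequence holds one entry per Python UDF in the projection, the inner sequence holds the chain of Python functions fused into that UDF, and numArgs records each top-level UDF's arity.

    # Sketch for "SELECT double(a), double(double(b))"
    double = lambda v: v + v
    funcs = [[double], [double, double]]  # outer: one entry per UDF column; inner: that column's chained functions
    num_args = [1, 1]                     # arity of each top-level UDF; len(funcs) == len(num_args)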

bufferSize: Int,
reuse_worker: Boolean,
rowBased: Boolean)
isUDF: Boolean,
Contributor: Similarly, do you mind adding scaladoc for these two new parameters?

numArgs: Seq[Int])
extends Logging {

// All the Python functions should have the same exec, version and envvars.
private val envVars = funcs.head.envVars
private val pythonExec = funcs.head.pythonExec
private val pythonVer = funcs.head.pythonVer
private val envVars = funcs.head.head.envVars
private val pythonExec = funcs.head.head.pythonExec
private val pythonVer = funcs.head.head.pythonVer

private val accumulator = funcs.head.accumulator // TODO: support accumulator in multiple UDF
private val accumulator = funcs.head.head.accumulator // TODO: support accumulator in multiple UDF

def compute(
inputIterator: Iterator[_],
@@ -232,8 +240,8 @@ private[spark] class PythonRunner(

@volatile private var _exception: Exception = null

private val pythonIncludes = funcs.flatMap(_.pythonIncludes.asScala).toSet
private val broadcastVars = funcs.flatMap(_.broadcastVars.asScala)
private val pythonIncludes = funcs.flatMap(_.flatMap(_.pythonIncludes.asScala)).toSet
private val broadcastVars = funcs.flatMap(_.flatMap(_.broadcastVars.asScala))

setDaemon(true)

@@ -284,11 +292,22 @@
}
dataOut.flush()
// Serialized command:
dataOut.writeInt(if (rowBased) 1 else 0)
dataOut.writeInt(funcs.length)
funcs.foreach { f =>
dataOut.writeInt(f.command.length)
dataOut.write(f.command)
if (isUDF) {
dataOut.writeInt(1)
dataOut.writeInt(funcs.length)
funcs.zip(numArgs).foreach { case (fs, numArg) =>
Contributor: Since correctness relies on funcs.length == numArgs.length, do you mind adding a require at the start of the constructor to enforce this?

dataOut.writeInt(numArg)
dataOut.writeInt(fs.length)
fs.foreach { f =>
dataOut.writeInt(f.command.length)
dataOut.write(f.command)
}
}
} else {
dataOut.writeInt(0)
val command = funcs.head.head.command
dataOut.writeInt(command.length)
dataOut.write(command)
}
// Data values
PythonRDD.writeIteratorToStream(inputIterator, dataOut)
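For reference, a minimal Python sketch of the framing the branch above writes when isUDF is true (write_int and write_udf_batch are illustrative helper names, not Spark APIs; the real reader lives in python/pyspark/worker.py further down):

    import struct

    def write_int(stream, value):
        # 4-byte big-endian int, matching java.io.DataOutputStream.writeInt
        stream.write(struct.pack(">i", value))

    def write_udf_batch(stream, udf_batches):
        # udf_batches: list of (num_args, [pickled_command, ...]), one entry per UDF column
        write_int(stream, 1)                     # isUDF flag
        write_int(stream, len(udf_batches))      # number of UDFs
        for num_args, commands in udf_batches:
            write_int(stream, num_args)          # arity of this UDF
            write_int(stream, len(commands))     # length of the chained-function pipeline
            for command in commands:
                write_int(stream, len(command))  # length-prefixed pickled command
                stream.write(command)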
3 changes: 1 addition & 2 deletions python/pyspark/sql/functions.py
@@ -1649,8 +1649,7 @@ def sort_array(col, asc=True):
# ---------------------------- User Defined Function ----------------------------------

def _wrap_function(sc, func, returnType):
ser = AutoBatchedSerializer(PickleSerializer())
command = (func, returnType, ser)
command = (func, returnType)
pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
sc.pythonVer, broadcast_vars, sc._javaAccumulator)
11 changes: 10 additions & 1 deletion python/pyspark/sql/tests.py
@@ -305,7 +305,7 @@ def test_udf2(self):
[res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
self.assertEqual(4, res[0])

def test_chained_python_udf(self):
def test_chained_udf(self):
self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType())
[row] = self.sqlCtx.sql("SELECT double(1)").collect()
self.assertEqual(row[0], 2)
@@ -314,6 +314,15 @@ def test_chained_python_udf(self):
[row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect()
self.assertEqual(row[0], 6)

def test_multiple_udfs(self):
[row] = self.sqlCtx.sql("SELECT double(1), double(2)").collect()
self.assertEqual(tuple(row), (2, 4))
[row] = self.sqlCtx.sql("SELECT double(double(1)), double(double(2) + 2)").collect()
self.assertEqual(tuple(row), (4, 12))
self.sqlCtx.registerFunction("add", lambda x, y: x + y, IntegerType())
[row] = self.sqlCtx.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect()
self.assertEqual(tuple(row), (6, 5))

def test_udf_with_array_type(self):
d = [Row(l=list(range(3)), d={"key": list(range(5))})]
rdd = self.sc.parallelize(d)
62 changes: 46 additions & 16 deletions python/pyspark/worker.py
@@ -29,7 +29,7 @@
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, AutoBatchedSerializer
from pyspark import shuffle

pickleSer = PickleSerializer()
@@ -59,7 +59,48 @@ def read_command(serializer, file):

def chain(f, g):
"""chain two function together """
return lambda x: g(f(x))
return lambda *a: g(f(*a))
Contributor: Woah, didn't know that you could do varargs lambdas. Cool!
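A quick illustration of what chain builds (not part of the patch): only the first function in the chain receives the UDF's arguments, each later function consumes the previous result, and the last returnType becomes the UDF's return type.

    add = lambda x, y: x + y
    double = lambda v: v + v
    chained = lambda *a: double(add(*a))  # what chain(add, double) produces
    assert chained(1, 2) == 6             # add(1, 2) -> 3, then double(3) -> 6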



def wrap_udf(f, return_type):
return lambda *a: return_type.toInternal(f(*a))


def read_single_udf(pickleSer, infile):
num_arg = read_int(infile)
row_func = None
for i in range(read_int(infile)):
f, return_type = read_command(pickleSer, infile)
if row_func is None:
row_func = f
else:
row_func = chain(row_func, f)
# the last returnType will be the return type of UDF
return num_arg, wrap_udf(row_func, return_type)


def read_udfs(pickleSer, infile):
num_udfs = read_int(infile)
udfs = []
Contributor: It looks like udfs holds (something, something, udf) triples. Mind adding a line-comment here to say what the first two components of the tuple correspond to?

offset = 0
for i in range(num_udfs):
num_arg, udf = read_single_udf(pickleSer, infile)
udfs.append((offset, offset + num_arg, udf))
offset += num_arg

if num_udfs == 1:
udf = udfs[0][2]

def mapper(args):
Contributor: I bet you could even do mapper = udf if you wanted to.
Contributor (author): We can't, the input of mapper is a tuple, but udf's is not.
Contributor: Ah, got it. Makes sense.

return (udf(*args),)
else:
def mapper(args):
return tuple(udf(*args[start:end]) for start, end, udf in udfs)

func = lambda _, it: map(mapper, it)
ser = AutoBatchedSerializer(PickleSerializer())
# profiling is not supported for UDF
return func, None, ser, ser
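To make the (start, end, udf) triples from the thread above concrete, a small standalone sketch (hypothetical values): each UDF receives the slice of the flattened argument tuple that belongs to it, which is also why mapper has to unpack a tuple rather than simply being the udf.

    double = lambda v: v + v
    add = lambda x, y: x + y
    # (start, end, udf): args[start:end] is the slice of the flattened row that feeds this UDF
    udfs = [(0, 1, double), (1, 3, add)]

    def mapper(args):
        return tuple(udf(*args[start:end]) for start, end, udf in udfs)

    # one input tuple carries the arguments of both UDFs, flattened
    assert mapper((3, 1, 2)) == (6, 3)    # double(3) -> 6, add(1, 2) -> 3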


def main(infile, outfile):
@@ -107,21 +148,10 @@ def main(infile, outfile):
_broadcastRegistry.pop(bid)

_accumulatorRegistry.clear()
row_based = read_int(infile)
num_commands = read_int(infile)
if row_based:
profiler = None # profiling is not supported for UDF
row_func = None
for i in range(num_commands):
f, returnType, deserializer = read_command(pickleSer, infile)
if row_func is None:
row_func = f
else:
row_func = chain(row_func, f)
serializer = deserializer
func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it)
is_udf = read_int(infile)
if is_udf:
Contributor: I'd maybe call this is_sql_udf just to make it clearer that this is part of PySpark SQL support, but I don't feel strongly about this.

func, profiler, deserializer, serializer = read_udfs(pickleSer, infile)
else:
assert num_commands == 1
func, profiler, deserializer, serializer = read_command(pickleSer, infile)

init_time = time.time()
@@ -426,8 +426,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.RepartitionByExpression(expressions, child, nPartitions) =>
exchange.ShuffleExchange(HashPartitioning(
expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
case e @ python.EvaluatePython(udf, child, _) =>
python.BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
case e @ python.EvaluatePython(udfs, child, _) =>
python.BatchPythonEvaluation(udfs, e.output, planLater(child)) :: Nil
case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil
case BroadcastHint(child) => planLater(child) :: Nil
case _ => Nil
@@ -40,7 +40,7 @@ import org.apache.spark.sql.types.{StructField, StructType}
* we drain the queue to find the original input row. Note that if the Python process is way too
* slow, this could lead to the queue growing unbounded and eventually run out of memory.
*/
case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan)
case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
extends SparkPlan {

def children: Seq[SparkPlan] = child :: Nil
@@ -69,11 +69,15 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
// combine input with output from Python.
val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()

val (pyFuncs, children) = collectFunctions(udf)
val (pyFuncs, children) = udfs.map(collectFunctions).unzip
val numArgs = children.map(_.length)
val resultType = StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))

val pickle = new Pickler
val currentRow = newMutableProjection(children, child.output)()
val fields = children.map(_.dataType)
// flatten all the arguments
val allChildren = children.flatMap(x => x)
Contributor: Quick clarification: if I have a query like select udf(x), udf2(x), udf3(x), udf4(x) from ..., we'll send the x column's value four times to PySpark? I know that we have a conceptually similar problem when we're evaluating multiple aggregates in parallel in JVM Spark SQL, but in that case I think we only project each column once and end up rebinding the references / offsets to reference the single copy.

My hunch is that this extra copy isn't a huge perf. issue compared to the slow multiple-Python-UDF evaluation strategy we were using before, so I think it's fine to leave this for now. If it does become a problem, we could optimize later.
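A small illustration of the flattening being discussed (hypothetical column names): any column used by more than one UDF is projected once per use.

    # For "SELECT udf(x), udf2(x, y)":
    children = [["x"], ["x", "y"]]                         # per-UDF argument expressions
    all_children = [c for args in children for c in args]  # ["x", "x", "y"] -- x appears twice
    num_args = [len(args) for args in children]            # [1, 2]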

val currentRow = newMutableProjection(allChildren, child.output)()
val fields = allChildren.map(_.dataType)
val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)

// Input iterator to Python: input rows are grouped so we send them in batches to Python.
@@ -89,7 +93,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
val context = TaskContext.get()

// Output iterator for results from Python.
val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true)
val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, numArgs)
.compute(inputIterator, context.partitionId(), context)

val unpickle = new Unpickler
@@ -101,7 +105,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
val unpickledBatch = unpickle.loads(pickedResult)
unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
}.map { result =>
row(0) = EvaluatePython.fromJava(result, udf.dataType)
val row = EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
resultProj(joined(queue.poll(), row))
}
}
@@ -36,24 +36,28 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

/**
* Evaluates a [[PythonUDF]], appending the result to the end of the input tuple.
* Evaluates a list of [[PythonUDF]], appending the result to the end of the input tuple.
*/
case class EvaluatePython(
udf: PythonUDF,
udfs: Seq[PythonUDF],
child: LogicalPlan,
resultAttribute: AttributeReference)
resultAttribute: Seq[AttributeReference])
extends logical.UnaryNode {

def output: Seq[Attribute] = child.output :+ resultAttribute
def output: Seq[Attribute] = child.output ++ resultAttribute

// References should not include the produced attribute.
override def references: AttributeSet = udf.references
override def references: AttributeSet = AttributeSet(udfs.flatMap(_.references))
}


object EvaluatePython {
def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython =
new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)())
def apply(udfs: Seq[PythonUDF], child: LogicalPlan): EvaluatePython = {
val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
AttributeReference(s"pythonUDF$i", u.dataType)()
}
new EvaluatePython(udfs, child, resultAttrs)
}
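A minimal sketch of the resulting plan output (illustrative values only): the child's columns are kept and one generated attribute per UDF is appended, named pythonUDF0, pythonUDF1, and so on.

    child_output = ["a", "b"]
    num_udfs = 2
    result_attrs = ["pythonUDF%d" % i for i in range(num_udfs)]
    output = child_output + result_attrs
    assert output == ["a", "b", "pythonUDF0", "pythonUDF1"]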

def takeAndServe(df: DataFrame, n: Int): Int = {
registerPicklers()