[SPARK-14267] [SQL] [PYSPARK] execute multiple Python UDFs within single batch #12057
```diff
@@ -77,30 +77,42 @@ private[spark] case class PythonFunction(
     broadcastVars: JList[Broadcast[PythonBroadcast]],
     accumulator: Accumulator[JList[Array[Byte]]])
 
+/**
+ * A wrapper for chained Python functions (from bottom to top).
+ * @param funcs
+ */
+private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])
+
-object PythonRunner {
+private[spark] object PythonRunner {
   def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
-    new PythonRunner(Seq(Seq(func)), bufferSize, reuse_worker, false, Seq(1))
+    new PythonRunner(
+      Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuse_worker, false, Seq(Seq(0)))
   }
 }
 
 /**
  * A helper class to run Python mapPartition/UDFs in Spark.
+ *
+ * funcs is a list of independent Python functions, each one of them is a list of chained Python
+ * functions (from bottom to top).
  */
 private[spark] class PythonRunner(
-    funcs: Seq[Seq[PythonFunction]],
+    funcs: Seq[ChainedPythonFunctions],
     bufferSize: Int,
     reuse_worker: Boolean,
     isUDF: Boolean,
-    numArgs: Seq[Int])
+    argOffsets: Seq[Seq[Int]])
   extends Logging {
 
+  require(funcs.length == argOffsets.length, "numArgs should have the same length as funcs")
+
   // All the Python functions should have the same exec, version and envvars.
-  private val envVars = funcs.head.head.envVars
-  private val pythonExec = funcs.head.head.pythonExec
-  private val pythonVer = funcs.head.head.pythonVer
+  private val envVars = funcs.head.funcs.head.envVars
+  private val pythonExec = funcs.head.funcs.head.pythonExec
+  private val pythonVer = funcs.head.funcs.head.pythonVer
 
-  private val accumulator = funcs.head.head.accumulator // TODO: support accumulator in multiple UDF
+  // TODO: support accumulator in multiple UDF
+  private val accumulator = funcs.head.funcs.head.accumulator
 
   def compute(
       inputIterator: Iterator[_],
```
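To make the new shape concrete, here is a small standalone sketch (plain Python, not Spark code; all names are made up) of what `funcs: Seq[ChainedPythonFunctions]` and `argOffsets: Seq[Seq[Int]]` describe: one entry per independent UDF, each entry a bottom-to-top chain of functions plus the offsets of its arguments in the shared, flattened input row.

```python
# Hypothetical illustration: the projected input row holds the deduplicated
# columns; each UDF entry is a bottom-to-top chain of functions plus the
# offsets of the columns it consumes.
f = lambda x: x + 1          # stands in for a single, unchained UDF
h = lambda x, y: x * y       # bottom of a chained UDF g(h(x, y))
g = lambda x: -x             # top of the chain

funcs = [[f], [h, g]]        # mirrors Seq(ChainedPythonFunctions(...), ...)
arg_offsets = [[0], [1, 2]]  # mirrors Seq(Seq(0), Seq(1, 2))

def evaluate(row):
    out = []
    for chain, offsets in zip(funcs, arg_offsets):
        args = [row[o] for o in offsets]
        result = chain[0](*args)      # the bottom function sees the raw columns
        for fn in chain[1:]:          # the rest of the chain is applied on top
            result = fn(result)
        out.append(result)
    return tuple(out)

print(evaluate((3, 4, 5)))            # -> (4, -20)
```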
```diff
@@ -240,8 +252,8 @@ private[spark] class PythonRunner(
 
     @volatile private var _exception: Exception = null
 
-    private val pythonIncludes = funcs.flatMap(_.flatMap(_.pythonIncludes.asScala)).toSet
-    private val broadcastVars = funcs.flatMap(_.flatMap(_.broadcastVars.asScala))
+    private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
+    private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))
 
     setDaemon(true)
```
```diff
@@ -295,17 +307,20 @@ private[spark] class PythonRunner(
         if (isUDF) {
           dataOut.writeInt(1)
           dataOut.writeInt(funcs.length)
-          funcs.zip(numArgs).foreach { case (fs, numArg) =>
-            dataOut.writeInt(numArg)
-            dataOut.writeInt(fs.length)
-            fs.foreach { f =>
+          funcs.zip(argOffsets).foreach { case (chained, offsets) =>
+            dataOut.writeInt(offsets.length)
+            offsets.foreach { offset =>
+              dataOut.writeInt(offset)
+            }
+            dataOut.writeInt(chained.funcs.length)
+            chained.funcs.foreach { f =>
               dataOut.writeInt(f.command.length)
               dataOut.write(f.command)
             }
           }
         } else {
           dataOut.writeInt(0)
-          val command = funcs.head.head.command
+          val command = funcs.head.funcs.head.command
           dataOut.writeInt(command.length)
           dataOut.write(command)
         }
```
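For reference, here is a rough sketch of how the per-UDF section written above could be decoded on the worker side, assuming the leading is-UDF flag has already been consumed. It mirrors the `read_udfs`/`read_single_udf` changes below, but `read_udf_descriptions` and its stream handling are made up for illustration.

```python
import struct

def read_int(stream):
    # DataOutputStream.writeInt emits a big-endian 4-byte int
    return struct.unpack(">i", stream.read(4))[0]

def read_udf_descriptions(stream):
    """Decode the per-UDF layout written by the isUDF branch above (sketch)."""
    udfs = []
    for _ in range(read_int(stream)):          # funcs.length
        num_offsets = read_int(stream)         # offsets.length
        offsets = [read_int(stream) for _ in range(num_offsets)]
        chain = []
        for _ in range(read_int(stream)):      # chained.funcs.length
            length = read_int(stream)          # f.command.length
            chain.append(stream.read(length))  # pickled command bytes
        udfs.append((offsets, chain))
    return udfs
```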
```diff
@@ -63,11 +63,13 @@ def chain(f, g):
 
 
 def wrap_udf(f, return_type):
-    return lambda *a: return_type.toInternal(f(*a))
+    toInternal = return_type.toInternal
+    return lambda *a: toInternal(f(*a))
 
 
 def read_single_udf(pickleSer, infile):
     num_arg = read_int(infile)
+    arg_offsets = [read_int(infile) for i in range(num_arg)]
     row_func = None
     for i in range(read_int(infile)):
         f, return_type = read_command(pickleSer, infile)
```
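The `wrap_udf` change is a small micro-optimization: looking up `return_type.toInternal` once, outside the lambda, avoids a repeated attribute lookup on every row. A standalone sketch of the same pattern (the class and timing harness are illustrative only):

```python
import timeit

class ReturnType(object):
    def toInternal(self, v):
        return v

return_type = ReturnType()

# attribute lookup happens on every call
slow = lambda v: return_type.toInternal(v)

# lookup hoisted out of the lambda, as wrap_udf now does
toInternal = return_type.toInternal
fast = lambda v: toInternal(v)

print(timeit.timeit(lambda: slow(1), number=100000))
print(timeit.timeit(lambda: fast(1), number=100000))
```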
@@ -76,27 +78,27 @@ def read_single_udf(pickleSer, infile): | |
| else: | ||
| row_func = chain(row_func, f) | ||
| # the last returnType will be the return type of UDF | ||
| return num_arg, wrap_udf(row_func, return_type) | ||
| return arg_offsets, wrap_udf(row_func, return_type) | ||
|
|
||
|
|
||
| def read_udfs(pickleSer, infile): | ||
| num_udfs = read_int(infile) | ||
| udfs = [] | ||
| offset = 0 | ||
| for i in range(num_udfs): | ||
| num_arg, udf = read_single_udf(pickleSer, infile) | ||
| udfs.append((offset, offset + num_arg, udf)) | ||
| offset += num_arg | ||
|
|
||
| if num_udfs == 1: | ||
| udf = udfs[0][2] | ||
|
|
||
| # fast path for single UDF | ||
| def mapper(args): | ||
| return udf(*args) | ||
| _, udf = read_single_udf(pickleSer, infile) | ||
| mapper = lambda a: udf(*a) | ||
| else: | ||
| def mapper(args): | ||
| return tuple(udf(*args[start:end]) for start, end, udf in udfs) | ||
| udfs = {} | ||
| call_udf = [] | ||
| for i in range(num_udfs): | ||
| arg_offsets, udf = read_single_udf(pickleSer, infile) | ||
| udfs['f%d' % i] = udf | ||
| args = ["a[%d]" % o for o in arg_offsets] | ||
| call_udf.append("f%d(%s)" % (i, ", ".join(args))) | ||
| # Create function like this: | ||
| # lambda a: (f0(a0), f1(a1, a2), f2(a3)) | ||
| mapper_str = "lambda a: (%s)" % (", ".join(call_udf)) | ||
|
Contributor: Clever! This is a neat trick.
```diff
+        mapper = eval(mapper_str, udfs)
 
     func = lambda _, it: map(mapper, it)
     ser = AutoBatchedSerializer(PickleSerializer())
```
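A self-contained sketch of the code-generation trick above: the mapper is assembled as a source string that indexes each UDF's arguments by offset and is compiled once with `eval`. The stand-in UDFs and offsets below are made up; in the real worker they come from `read_single_udf`.

```python
# Stand-in UDFs and their argument offsets into the shared input row.
udf_list = [(lambda x: x + 1, [0]),
            (lambda x, y: x * y, [1, 2])]

udfs = {}
call_udf = []
for i, (udf, arg_offsets) in enumerate(udf_list):
    udfs['f%d' % i] = udf
    args = ["a[%d]" % o for o in arg_offsets]
    call_udf.append("f%d(%s)" % (i, ", ".join(args)))

# e.g. "lambda a: (f0(a[0]), f1(a[1], a[2]))"
mapper_str = "lambda a: (%s)" % ", ".join(call_udf)
mapper = eval(mapper_str, udfs)

print(mapper_str)
print(mapper((3, 4, 5)))   # -> (4, 20)
```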
```diff
@@ -149,8 +151,8 @@ def main(infile, outfile):
                 _broadcastRegistry.pop(bid)
 
         _accumulatorRegistry.clear()
-        is_udf = read_int(infile)
-        if is_udf:
+        is_sql_udf = read_int(infile)
+        if is_sql_udf:
             func, profiler, deserializer, serializer = read_udfs(pickleSer, infile)
         else:
             func, profiler, deserializer, serializer = read_command(pickleSer, infile)
```
```diff
@@ -18,16 +18,17 @@
 package org.apache.spark.sql.execution.python
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
 
 import net.razorvine.pickle.{Pickler, Unpickler}
 
 import org.apache.spark.TaskContext
-import org.apache.spark.api.python.{PythonFunction, PythonRunner}
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonFunction, PythonRunner}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
 
 
 /**
```
```diff
@@ -45,15 +46,15 @@ case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], c
 
   def children: Seq[SparkPlan] = child :: Nil
 
-  private def collectFunctions(udf: PythonUDF): (Seq[PythonFunction], Seq[Expression]) = {
+  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
     udf.children match {
       case Seq(u: PythonUDF) =>
-        val (fs, children) = collectFunctions(u)
-        (fs ++ Seq(udf.func), children)
+        val (chained, children) = collectFunctions(u)
+        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
       case children =>
         // There should not be any other UDFs, or the children can't be evaluated directly.
         assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
-        (Seq(udf.func), udf.children)
+        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
     }
   }
```
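In words, `collectFunctions` walks a column of nested Python UDFs from the top down and returns the chain in bottom-to-top order together with the inputs of the innermost UDF. A rough Python rendering of the same recursion (the tuple-based expression model is invented for illustration):

```python
# A nested UDF expression f(g(col)) modeled as (func_name, children).
expr = ("f", [("g", ["col"])])

def collect_functions(udf):
    """Return (chained_funcs_bottom_to_top, inputs_of_the_innermost_udf)."""
    name, children = udf
    if len(children) == 1 and isinstance(children[0], tuple):
        chained, inputs = collect_functions(children[0])
        return chained + [name], inputs
    return [name], children

print(collect_functions(expr))   # -> (['g', 'f'], ['col'])
```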
```diff
@@ -69,30 +70,48 @@ case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], c
     // combine input with output from Python.
     val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()
 
-    val (pyFuncs, children) = udfs.map(collectFunctions).unzip
-    val numArgs = children.map(_.length)
+    val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
 
-    val pickle = new Pickler
+    // Most of the inputs are primitives, do not use memo for better performance
+    val pickle = new Pickler(false)
     // flatten all the arguments
-    val allChildren = children.flatMap(x => x)
-    val currentRow = newMutableProjection(allChildren, child.output)()
-    val fields = allChildren.map(_.dataType)
-    val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
+    val allInputs = new ArrayBuffer[Expression]
+    val dataTypes = new ArrayBuffer[DataType]
+    val argOffsets = inputs.map { input =>
+      input.map { e =>
+        if (allInputs.exists(_.semanticEquals(e))) {
```
Contributor: In the worst case this loop is N^2, but N is probably pretty small so it probably doesn't matter compared to other perf issues impacting Python UDFs.

Contributor (author): Agreed.
```diff
+          allInputs.indexWhere(_.semanticEquals(e))
+        } else {
+          allInputs += e
+          dataTypes += e.dataType
+          allInputs.length - 1
+        }
+      }
+    }
+    val projection = newMutableProjection(allInputs, child.output)()
```
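The loop above deduplicates UDF inputs via `semanticEquals` and records, for each UDF, the offset of every argument in the shared projected row, so a column used by several UDFs is sent to Python only once. A small Python sketch of the same bookkeeping, with plain equality standing in for `semanticEquals`:

```python
# Two UDFs over the columns a and b: f(a) and g(a, b); column a is shared.
inputs = [["a"], ["a", "b"]]

all_inputs = []    # deduplicated expressions, in first-seen order
arg_offsets = []   # per UDF: offsets into all_inputs
for udf_input in inputs:
    offsets = []
    for e in udf_input:
        if e in all_inputs:                 # stands in for semanticEquals
            offsets.append(all_inputs.index(e))
        else:
            all_inputs.append(e)
            offsets.append(len(all_inputs) - 1)
    arg_offsets.append(offsets)

print(all_inputs)    # -> ['a', 'b']    (column a is sent to Python only once)
print(arg_offsets)   # -> [[0], [0, 1]]
```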
```diff
     // Input iterator to Python: input rows are grouped so we send them in batches to Python.
     // For each row, add it to the queue.
-    val inputIterator = iter.grouped(100).map { inputRows =>
-      val toBePickled = inputRows.map { row =>
-        queue.add(row)
-        EvaluatePython.toJava(currentRow(row), schema)
+    val inputIterator = iter.grouped(1024).map { inputRows =>
+      val toBePickled = inputRows.map { inputRow =>
+        queue.add(inputRow)
+        val row = projection(inputRow)
+        val fields = new Array[Any](row.numFields)
+        var i = 0
+        while (i < row.numFields) {
+          val dt = dataTypes(i)
+          fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
+          i += 1
+        }
+        fields
       }.toArray
       pickle.dumps(toBePickled)
     }
```
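On the input side, each row is now projected down to the deduplicated UDF inputs and rows are pickled in batches of 1024. A rough sketch of that round trip using Python's `pickle` in place of Pyrolite (batch size aside, the helpers and data are illustrative):

```python
import pickle

BATCH_SIZE = 1024   # mirrors iter.grouped(1024)

def batches(rows, size=BATCH_SIZE):
    batch = []
    for row in rows:
        batch.append(row)
        if len(batch) == size:
            yield pickle.dumps(batch)   # one pickled payload per batch
            batch = []
    if batch:
        yield pickle.dumps(batch)

# Projected rows: only the deduplicated UDF inputs survive the projection.
projected = [[3, 4, 5], [6, 7, 8]]
payloads = list(batches(projected))

# Worker side: unpickle a batch, apply the generated mapper per row.
mapper = lambda a: (a[0] + 1, a[1] * a[2])
for payload in payloads:
    print([mapper(row) for row in pickle.loads(payload)])
```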
```diff
     val context = TaskContext.get()
 
     // Output iterator for results from Python.
-    val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, numArgs)
+    val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, argOffsets)
       .compute(inputIterator, context.partitionId(), context)
 
     val unpickle = new Unpickler
```
Similarly, do you mind adding scaladoc for these two new parameters?