[SPARK-14267] [SQL] [PYSPARK] execute multiple Python UDFs within single batch #12057
core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -59,7 +59,7 @@ private[spark] class PythonRDD(
   val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)

   override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
-    val runner = new PythonRunner(Seq(func), bufferSize, reuse_worker, false)
+    val runner = PythonRunner(func, bufferSize, reuse_worker)
     runner.compute(firstParent.iterator(split, context), split.index, context)
   }
 }
@@ -77,22 +77,30 @@ private[spark] case class PythonFunction(
     broadcastVars: JList[Broadcast[PythonBroadcast]],
     accumulator: Accumulator[JList[Array[Byte]]])

+object PythonRunner {
+  def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
+    new PythonRunner(Seq(Seq(func)), bufferSize, reuse_worker, false, Seq(1))
+  }
+}
+
 /**
- * A helper class to run Python UDFs in Spark.
+ * A helper class to run Python mapPartition/UDFs in Spark.
  */
 private[spark] class PythonRunner(
-    funcs: Seq[PythonFunction],
+    funcs: Seq[Seq[PythonFunction]],
     bufferSize: Int,
     reuse_worker: Boolean,
-    rowBased: Boolean)
+    isUDF: Boolean,
+    numArgs: Seq[Int])
   extends Logging {

   // All the Python functions should have the same exec, version and envvars.
-  private val envVars = funcs.head.envVars
-  private val pythonExec = funcs.head.pythonExec
-  private val pythonVer = funcs.head.pythonVer
+  private val envVars = funcs.head.head.envVars
+  private val pythonExec = funcs.head.head.pythonExec
+  private val pythonVer = funcs.head.head.pythonVer

-  private val accumulator = funcs.head.accumulator // TODO: support accumulator in multiple UDF
+  private val accumulator = funcs.head.head.accumulator // TODO: support accumulator in multiple UDF

   def compute(
       inputIterator: Iterator[_],

Contributor (review comment on the new isUDF/numArgs parameters): Similarly, do you mind adding scaladoc for these two new parameters?
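To make the new shape concrete, here is a plain-Python illustration (not Spark code; the function names are invented) of what the nested funcs argument and numArgs represent: each inner sequence is a pipeline of chained functions that share the same arguments, and numArgs(i) is the number of input columns fed to UDF i.

    # Illustration only -- plain Python standing in for Seq[Seq[PythonFunction]].
    to_int = lambda s: int(s)        # UDF 0, first chained stage
    square = lambda x: x * x         # UDF 0, second chained stage
    add = lambda x, y: x + y         # UDF 1, a two-argument UDF

    funcs = [[to_int, square],       # UDF 0: square(to_int(col0))
             [add]]                  # UDF 1: add(col1, col2)
    num_args = [1, 2]                # arity of each UDF, in the same order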
@@ -232,8 +240,8 @@ private[spark] class PythonRunner(
     @volatile private var _exception: Exception = null

-    private val pythonIncludes = funcs.flatMap(_.pythonIncludes.asScala).toSet
-    private val broadcastVars = funcs.flatMap(_.broadcastVars.asScala)
+    private val pythonIncludes = funcs.flatMap(_.flatMap(_.pythonIncludes.asScala)).toSet
+    private val broadcastVars = funcs.flatMap(_.flatMap(_.broadcastVars.asScala))

     setDaemon(true)
@@ -284,11 +292,22 @@ private[spark] class PythonRunner(
         }
         dataOut.flush()
         // Serialized command:
-        dataOut.writeInt(if (rowBased) 1 else 0)
-        dataOut.writeInt(funcs.length)
-        funcs.foreach { f =>
-          dataOut.writeInt(f.command.length)
-          dataOut.write(f.command)
+        if (isUDF) {
+          dataOut.writeInt(1)
+          dataOut.writeInt(funcs.length)
+          funcs.zip(numArgs).foreach { case (fs, numArg) =>
+            dataOut.writeInt(numArg)
+            dataOut.writeInt(fs.length)
+            fs.foreach { f =>
+              dataOut.writeInt(f.command.length)
+              dataOut.write(f.command)
+            }
+          }
+        } else {
+          dataOut.writeInt(0)
+          val command = funcs.head.head.command
+          dataOut.writeInt(command.length)
+          dataOut.write(command)
         }
         // Data values
         PythonRDD.writeIteratorToStream(inputIterator, dataOut)
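For readers tracing the handshake, the sketch below (not Spark source; the helper names are mine) spells out the control-stream layout that the isUDF branch above writes and that read_udfs on the Python side consumes. DataOutputStream.writeInt emits big-endian 32-bit integers, which is what the worker's read_int unpacks.

    # Sketch of the UDF-batch framing, under the assumptions stated above.
    import struct

    def write_int(stream, n):
        stream.write(struct.pack("!i", n))   # big-endian, like DataOutputStream.writeInt

    def write_udf_batch(stream, udfs):
        """udfs: list of (num_args, [pickled_command, ...]) tuples, one per UDF."""
        write_int(stream, 1)                 # isUDF flag
        write_int(stream, len(udfs))         # number of UDFs in this batch
        for num_args, commands in udfs:
            write_int(stream, num_args)      # how many input columns this UDF takes
            write_int(stream, len(commands)) # length of the chained pipeline
            for cmd in commands:
                write_int(stream, len(cmd))  # length-prefixed pickled (func, returnType)
                stream.write(cmd)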
python/pyspark/worker.py
@@ -29,7 +29,7 @@
 from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, write_int, read_long, \
-    write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
+    write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, AutoBatchedSerializer
 from pyspark import shuffle

 pickleSer = PickleSerializer()
@@ -59,7 +59,48 @@ def read_command(serializer, file):

 def chain(f, g):
     """chain two function together """
-    return lambda x: g(f(x))
+    return lambda *a: g(f(*a))

Contributor (review comment): Woah, didn't know that you could do varargs lambdas. Cool!

+
+def wrap_udf(f, return_type):
+    return lambda *a: return_type.toInternal(f(*a))
+
+
+def read_single_udf(pickleSer, infile):
+    num_arg = read_int(infile)
+    row_func = None
+    for i in range(read_int(infile)):
+        f, return_type = read_command(pickleSer, infile)
+        if row_func is None:
+            row_func = f
+        else:
+            row_func = chain(row_func, f)
+    # the last returnType will be the return type of UDF
+    return num_arg, wrap_udf(row_func, return_type)
+
+
+def read_udfs(pickleSer, infile):
+    num_udfs = read_int(infile)
+    udfs = []
+    offset = 0
+    for i in range(num_udfs):
+        num_arg, udf = read_single_udf(pickleSer, infile)
+        udfs.append((offset, offset + num_arg, udf))
+        offset += num_arg
+
+    if num_udfs == 1:
+        udf = udfs[0][2]
+
+        def mapper(args):
+            return (udf(*args),)
+    else:
+        def mapper(args):
+            return tuple(udf(*args[start:end]) for start, end, udf in udfs)
+
+    func = lambda _, it: map(mapper, it)
+    ser = AutoBatchedSerializer(PickleSerializer())
+    # profiling is not supported for UDF
+    return func, None, ser, ser
+
+
 def main(infile, outfile):
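The offset bookkeeping in read_udfs is what lets several UDFs share one flattened argument tuple. A toy, Spark-free example (functions invented for illustration) of what mapper does to a single row:

    # Toy illustration of chained UDFs plus offset-based argument slicing.
    def chain(f, g):
        return lambda *a: g(f(*a))

    double = lambda x: x * 2
    inc = lambda x: x + 1
    add = lambda x, y: x + y

    udf1 = chain(double, inc)            # inc(double(x)): a chained one-argument UDF
    udf2 = add                           # a two-argument UDF

    # (start, end, udf): each UDF reads its own slice of the flattened argument tuple.
    udfs = [(0, 1, udf1), (1, 3, udf2)]

    def mapper(args):
        return tuple(udf(*args[start:end]) for start, end, udf in udfs)

    print(mapper((5, 2, 3)))             # (11, 5): one result field per UDF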
@@ -107,21 +148,10 @@ def main(infile, outfile):
                 _broadcastRegistry.pop(bid)

     _accumulatorRegistry.clear()
-    row_based = read_int(infile)
-    num_commands = read_int(infile)
-    if row_based:
-        profiler = None  # profiling is not supported for UDF
-        row_func = None
-        for i in range(num_commands):
-            f, returnType, deserializer = read_command(pickleSer, infile)
-            if row_func is None:
-                row_func = f
-            else:
-                row_func = chain(row_func, f)
-        serializer = deserializer
-        func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it)
+    is_udf = read_int(infile)
+    if is_udf:
+        func, profiler, deserializer, serializer = read_udfs(pickleSer, infile)
     else:
-        assert num_commands == 1
         func, profiler, deserializer, serializer = read_command(pickleSer, infile)

     init_time = time.time()
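At the API level, the point of the new read_udfs path is that a projection containing several Python UDFs can be evaluated in one pass through the worker instead of one round trip per UDF. A hedged end-user sketch (column names, data, and UDFs are made up; assumes a local PySpark build of this era):

    from pyspark import SparkContext
    from pyspark.sql import SQLContext
    from pyspark.sql.functions import udf
    from pyspark.sql.types import IntegerType, StringType

    sc = SparkContext("local[2]", "multi-udf-demo")
    sqlContext = SQLContext(sc)
    df = sqlContext.createDataFrame([(1, "alice"), (2, "bob")], ["age", "name"])

    add_one = udf(lambda x: x + 1, IntegerType())
    upper = udf(lambda s: s.upper(), StringType())

    # With this change, both UDFs are shipped to the Python worker together and
    # evaluated over the same batch of input rows.
    df.select(add_one(df["age"]), upper(df["name"])).show()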
sql/core: BatchPythonEvaluation (Python UDF evaluation operator)
@@ -40,7 +40,7 @@ import org.apache.spark.sql.types.{StructField, StructType}
  * we drain the queue to find the original input row. Note that if the Python process is way too
  * slow, this could lead to the queue growing unbounded and eventually run out of memory.
  */
-case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan)
+case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
   extends SparkPlan {

   def children: Seq[SparkPlan] = child :: Nil
@@ -69,11 +69,15 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
     // combine input with output from Python.
     val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()

-    val (pyFuncs, children) = collectFunctions(udf)
+    val (pyFuncs, children) = udfs.map(collectFunctions).unzip
+    val numArgs = children.map(_.length)
+    val resultType = StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))

     val pickle = new Pickler
-    val currentRow = newMutableProjection(children, child.output)()
-    val fields = children.map(_.dataType)
+    // flatten all the arguments
+    val allChildren = children.flatMap(x => x)
+    val currentRow = newMutableProjection(allChildren, child.output)()
+    val fields = allChildren.map(_.dataType)
     val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)

     // Input iterator to Python: input rows are grouped so we send them in batches to Python.
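Tying the two sides together: the worker returns one tuple per input row with one entry per UDF, and the JVM decodes that tuple against resultType (one unnamed field per UDF) before joining it back to the queued input row. A rough Python-only sketch of the result payload format (plain pickle standing in for the AutoBatchedSerializer/Unpickler pair):

    import pickle

    # Two input rows, two UDFs: each result row is a tuple with one field per UDF,
    # matching resultType = StructType(one StructField per UDF) on the Scala side.
    result_batch = [(11, "ALICE"), (13, "BOB")]

    payload = pickle.dumps(result_batch, protocol=2)   # roughly what the worker emits per batch
    rows = pickle.loads(payload)                       # roughly what the JVM unpickles
    assert rows[0] == (11, "ALICE")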
@@ -89,7 +93,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
     val context = TaskContext.get()

     // Output iterator for results from Python.
-    val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true)
+    val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, numArgs)
       .compute(inputIterator, context.partitionId(), context)

     val unpickle = new Unpickler
@@ -101,7 +105,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
       val unpickledBatch = unpickle.loads(pickedResult)
       unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
     }.map { result =>
-      row(0) = EvaluatePython.fromJava(result, udf.dataType)
+      val row = EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
       resultProj(joined(queue.poll(), row))
     }
   }
Contributor (review comment): This should be private[spark].