fast path for single UDF
Davies Liu committed Mar 30, 2016
commit 8e6e5bc623adfa77e25ebf02712b10ca1a8a7a60
1 change: 1 addition & 0 deletions python/pyspark/sql/tests.py
@@ -315,6 +315,7 @@ def test_chained_udf(self):
self.assertEqual(row[0], 6)

def test_multiple_udfs(self):
self.sqlCtx.registerFunction("double", lambda x: x * 2, IntegerType())
[row] = self.sqlCtx.sql("SELECT double(1), double(2)").collect()
self.assertEqual(tuple(row), (2, 4))
[row] = self.sqlCtx.sql("SELECT double(double(1)), double(double(2) + 2)").collect()
3 changes: 2 additions & 1 deletion python/pyspark/worker.py
@@ -91,8 +91,9 @@ def read_udfs(pickleSer, infile):
if num_udfs == 1:
udf = udfs[0][2]

+ # fast path for single UDF
def mapper(args):
Contributor

I bet you could even do mapper = udf if you wanted to.

Contributor Author

We can't: the input of mapper is a single tuple of arguments, but udf takes its arguments unpacked.

Contributor

Ah, got it. Makes sense.

- return (udf(*args),)
+ return udf(*args)
else:
def mapper(args):
return tuple(udf(*args[start:end]) for start, end, udf in udfs)
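
The exchange above turns on the calling convention: `mapper` receives one tuple holding all argument columns for a row, while `udf` takes its arguments unpacked. A minimal standalone sketch of the two paths follows. `make_mapper` and the sample functions are hypothetical stand-ins, and each `udfs` entry is assumed to be a `(start, end, udf)` triple, as the diff's `udfs[0][2]` and `for start, end, udf in udfs` suggest. It is not the actual `read_udfs` code.

```python
def make_mapper(udfs):
    """Build a per-row mapper from (start, end, udf) triples (hypothetical helper)."""
    if len(udfs) == 1:
        udf = udfs[0][2]

        # Fast path: return the bare value, without wrapping it in a 1-tuple.
        def mapper(args):
            return udf(*args)
    else:
        # General path: each UDF gets its own slice of the flattened arguments.
        def mapper(args):
            return tuple(udf(*args[start:end]) for start, end, udf in udfs)
    return mapper


# Sample functions, assumed purely for illustration.
double = lambda x: x * 2
add = lambda x, y: x + y

single = make_mapper([(0, 1, double)])
multi = make_mapper([(0, 1, double), (1, 3, add)])

single_result = single((21,))     # args arrive as a tuple, even for one column
multi_result = multi((1, 2, 3))   # double(1) and add(2, 3)
```

Assigning `mapper = udf` directly would hand the whole tuple to the UDF as its first argument; with this `double`, that would repeat the tuple instead of doubling the value inside it.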
@@ -71,7 +71,6 @@ case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], c

val (pyFuncs, children) = udfs.map(collectFunctions).unzip
val numArgs = children.map(_.length)
- val resultType = StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))

val pickle = new Pickler
// flatten all the arguments
@@ -97,15 +96,26 @@ case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], c
.compute(inputIterator, context.partitionId(), context)

val unpickle = new Unpickler
- val row = new GenericMutableRow(1)
+ val mutableRow = new GenericMutableRow(1)
val joined = new JoinedRow
+ val resultType = if (udfs.length == 1) {
+ udfs.head.dataType
+ } else {
+ StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
+ }
val resultProj = UnsafeProjection.create(output, output)

outputIterator.flatMap { pickedResult =>
val unpickledBatch = unpickle.loads(pickedResult)
unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
}.map { result =>
- val row = EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
+ val row = if (udfs.length == 1) {
Contributor

Rather than evaluating this if condition for every row, could we lift this out of the map and perform it once while building the RDD DAG? i.e. assign the result of line 108 to a variable and have the if be the last return value of this block?

Contributor

If you do this, you could reduce the scope of the mutableRow created up on line 99, too.

Contributor Author

Compared with the cost of evaluating the Python UDF, I think this does not matter; the JIT compiler can predict this branch pretty easily.

Contributor

Fair enough.

+ // fast path for single UDF
+ mutableRow(0) = EvaluatePython.fromJava(result, resultType)
+ mutableRow
+ } else {
+ EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
+ }
resultProj(joined(queue.poll(), row))
}
}
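
The hoisting idea raised in the review (decide single versus multiple once, not per row) can be illustrated with a small Python analogy. This is only a sketch: `convert_rows_per_row_check`, `convert_rows_hoisted`, and the tuple-based rows are hypothetical and do not correspond to Spark APIs, and the sketch says nothing about how the JVM JIT treats the branch.

```python
def convert_rows_per_row_check(results, num_udfs):
    """Check the UDF count for every result, mirroring the shape of the diff."""
    for result in results:
        if num_udfs == 1:
            yield (result,)       # wrap the bare value into a one-field row
        else:
            yield tuple(result)   # already holds one value per UDF


def convert_rows_hoisted(results, num_udfs):
    """Pick the conversion once, then apply it to every result."""
    if num_udfs == 1:
        to_row = lambda result: (result,)
    else:
        to_row = tuple
    for result in results:
        yield to_row(result)


# Sample batches, assumed purely for illustration.
single = [2, 4, 6]
multi = [(2, 4), (6, 8)]

same_single = (list(convert_rows_per_row_check(single, 1))
               == list(convert_rows_hoisted(single, 1)))
same_multi = (list(convert_rows_per_row_check(multi, 2))
              == list(convert_rows_hoisted(multi, 2)))
```

Both variants produce the same rows; as noted in the thread, any difference is likely negligible next to the cost of evaluating the Python UDF itself.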