Address comments
icexelloss committed Jun 12, 2018
commit 328b2c4e09502a66939d47d6967ceea7ceab6c8c
3 changes: 2 additions & 1 deletion python/pyspark/sql/functions.py
@@ -2613,7 +2613,8 @@ def pandas_udf(f=None, returnType=None, functionType=None):
>>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP
Member:
So we don't have PandasUDFType.WINDOW_AGG, and a pandas UDF defined as PandasUDFType.GROUPED_AGG can be used with both groupby and Window?

Contributor Author:
Yes, exactly. The idea is that the producer can define a grouped aggregate UDF, such as a weighted mean, and the consumer can use it with both groupby and window, similar to how SQL aggregate functions work.
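For illustration, a minimal sketch of that dual usage (assuming an active SparkSession and a df with columns id and v; the plain mean below stands in for a weighted mean):

from pyspark.sql import Window
from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def my_mean(v):
    # Aggregates a pandas Series to a single scalar.
    return v.mean()

# Consumer 1: groupby aggregation.
df.groupby('id').agg(my_mean(df['v']).alias('mean_v')).show()

# Consumer 2: window aggregation over an unbounded frame.
w = Window.partitionBy('id') \
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
df.withColumn('mean_v', my_mean(df['v']).over(w)).show()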

... def mean_udf(v):
... return v.mean()
- >>> w = Window.partitionBy('id')
+ >>> w = Window.partitionBy('id') \\
+ ...     .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
>>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() # doctest: +SKIP
+---+----+------+
| id| v|mean_v|
2 changes: 1 addition & 1 deletion python/pyspark/sql/tests.py
@@ -5577,7 +5577,7 @@ def test_multiple_udfs(self):

result1 = df.withColumn('mean_v', self.pandas_agg_mean_udf(df['v']).over(w)) \
.withColumn('max_v', self.pandas_agg_max_udf(df['v']).over(w)) \
-     .withColumn('min_w', self.pandas_agg_min_udf(df['w']).over(w)) \
+     .withColumn('min_w', self.pandas_agg_min_udf(df['w']).over(w))

expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) \
.withColumn('max_v', max(df['v']).over(w)) \
4 changes: 4 additions & 0 deletions python/pyspark/worker.py
@@ -129,6 +129,10 @@ def wrapped(*series):


def wrap_window_agg_pandas_udf(f, return_type):
+ # This is similar to grouped_agg_pandas_udf; the only difference
+ # is that window_agg_pandas_udf needs to repeat the return value
+ # to match the window length, whereas grouped_agg_pandas_udf just
+ # returns the scalar value.
arrow_return_type = to_arrow_type(return_type)

def wrapped(*series):
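For intuition, a minimal sketch of the repeat behavior described in the comment above (hypothetical names; the real wrapper also converts results to the Arrow return type):

import pandas as pd

def wrap_window_agg_sketch(f):
    def wrapped(*series):
        scalar = f(*series)   # aggregate the window's rows to one value
        n = len(series[0])    # the window length
        # Repeat the scalar so every row in the window gets the value,
        # unlike the grouped-agg wrapper, which returns the scalar as-is.
        return pd.Series([scalar]).repeat(n).reset_index(drop=True)
    return wrapped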
@@ -116,7 +116,7 @@ trait CheckAnalysis extends PredicateHelper {
case _ @ WindowExpression(_: PythonUDF,
WindowSpecDefinition(_, _, frame: SpecifiedWindowFrame))
if !frame.isUnbounded =>
failAnalysis(s"Only unbounded window frame is supported with Pandas UDFs.")
failAnalysis("Only unbounded window frame is supported with Pandas UDFs.")

case w @ WindowExpression(e, s) =>
// Only allow window functions with an aggregate expression or an offset window
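From the Python side, a hedged sketch of what this check rejects (reusing the hypothetical my_mean and df from the example above):

w_bounded = Window.partitionBy('id').rowsBetween(-1, 1)  # a bounded frame
# Expected to fail analysis with:
# "Only unbounded window frame is supported with Pandas UDFs."
df.withColumn('mean_v', my_mean(df['v']).over(w_bounded))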
@@ -39,10 +39,9 @@ object PythonUDF {
e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF
}

- def isWindowPandasUDF(e: Expression): Boolean = {
-   e.isInstanceOf[PythonUDF] &&
-     e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF
- }
+ // This is currently the same as GroupedAggPandasUDF, but we might support new types
+ // in the future, e.g. an N -> N transform.
+ def isWindowPandasUDF(e: Expression): Boolean = isGroupedAggPandasUDF(e)
}

/**
@@ -313,7 +313,7 @@ object WindowFunctionType {
case udf: PythonUDF if PythonUDF.isWindowPandasUDF(udf) => Python
}

- // Normally a window expression would either have either a SQL window function, a SQL
+ // Normally a window expression would either have a SQL window function, a SQL
// aggregate function or a python window UDF. However, sometimes the optimizer will replace
// the window function if the value of the window function can be predetermined.
// For example, for query:
@@ -279,7 +279,7 @@ object PhysicalAggregation {
*/
object PhysicalWindow {
// windowFunctionType, windowExpression, partitionSpec, orderSpec, child
- type ReturnType =
+ private type ReturnType =
(WindowFunctionType, Seq[NamedExpression], Seq[Expression], Seq[SortOrder], LogicalPlan)

def unapply(a: Any): Option[ReturnType] = a match {
@@ -439,6 +439,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
WindowFunctionType.Python, windowExprs, partitionSpec, orderSpec, child) =>
execution.python.WindowInPandasExec(
windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
Member:
Tiny nit: I would add a newline below.

Contributor Author:
Added.


case _ => Nil
}
}
@@ -80,7 +80,7 @@ case class WindowInPandasExec(
* @return the final resulting projection.
*/
private[this] def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = {
- val references = expressions.zipWithIndex.map{ case (e, i) =>
+ val references = expressions.zipWithIndex.map { case (e, i) =>
// Results of window expressions will be on the right side of child's output
BoundReference(child.output.size + i, e.dataType, e.nullable)
}