Changes from 1 commit
Initial commit of using grouped agg pandas UDFs as window functions
icexelloss committed Jun 12, 2018
commit 659e1dfc69180875d2a8936faea0c27fb6848953
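For context, this is the user-facing pattern the commit enables (mirrored by the new test in python/pyspark/sql/tests.py below): a grouped aggregate pandas UDF applied over a window spec. A minimal sketch, assuming an active SparkSession with pandas and PyArrow installed, and a DataFrame `df` with columns `id` and `v`:

from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.window import Window

# Grouped aggregate pandas UDF: consumes a pandas Series, returns a scalar.
@pandas_udf('double', PandasUDFType.GROUPED_AGG)
def mean_udf(v):
    return v.mean()

# Unbounded window per partition key; the scalar result is broadcast to each row.
w = Window.partitionBy('id').rowsBetween(
    Window.unboundedPreceding, Window.unboundedFollowing)

df.withColumn('mean_v', mean_udf(df['v']).over(w)).show()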
@@ -40,13 +40,15 @@ private[spark] object PythonEvalType {
val SQL_SCALAR_PANDAS_UDF = 200
val SQL_GROUPED_MAP_PANDAS_UDF = 201
val SQL_GROUPED_AGG_PANDAS_UDF = 202
val SQL_WINDOW_AGG_PANDAS_UDF = 203

def toString(pythonEvalType: Int): String = pythonEvalType match {
case NON_UDF => "NON_UDF"
case SQL_BATCHED_UDF => "SQL_BATCHED_UDF"
case SQL_SCALAR_PANDAS_UDF => "SQL_SCALAR_PANDAS_UDF"
case SQL_GROUPED_MAP_PANDAS_UDF => "SQL_GROUPED_MAP_PANDAS_UDF"
case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF"
case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF"
}
}

1 change: 1 addition & 0 deletions python/pyspark/rdd.py
@@ -74,6 +74,7 @@ class PythonEvalType(object):
SQL_SCALAR_PANDAS_UDF = 200
SQL_GROUPED_MAP_PANDAS_UDF = 201
SQL_GROUPED_AGG_PANDAS_UDF = 202
SQL_WINDOW_AGG_PANDAS_UDF = 203


def portable_hash(x):
1 change: 0 additions & 1 deletion python/pyspark/sql/functions.py
@@ -2405,7 +2405,6 @@ class PandasUDFType(object):

GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF


@since(1.3)
def udf(f=None, returnType=StringType()):
"""Creates a user defined function (UDF).
27 changes: 27 additions & 0 deletions python/pyspark/sql/tests.py
@@ -5479,6 +5479,33 @@ def test_invalid_args(self):
'mixture.*aggregate function.*group aggregate pandas UDF'):
df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()

@unittest.skipIf(
not _have_pandas or not _have_pyarrow,
_pandas_requirement_message or _pyarrow_requirement_message)
class WindowPandasUDFTests(ReusedSQLTestCase):
@property
def data(self):
from pyspark.sql.functions import array, explode, col, lit
return self.spark.range(10).toDF('id') \
.withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \
.withColumn("v", explode(col('vs'))) \
.drop('vs') \
.withColumn('w', lit(1.0))

def test_simple(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType, percent_rank, mean, max
from pyspark.sql.window import Window

df = self.data
w = Window.partitionBy('id').rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
mean_udf = pandas_udf(lambda v: v.mean(), 'double', PandasUDFType.GROUPED_AGG)

result1 = df.withColumn('mean_v', mean_udf(df['v']).over(w)).toPandas()
expected1 = df.withColumn('mean_v', mean(df['v']).over(w)).toPandas()

self.assertPandasEqual(expected1, result1)


if __name__ == "__main__":
from pyspark.sql.tests import *
if xmlrunner:
16 changes: 15 additions & 1 deletion python/pyspark/worker.py
@@ -127,6 +127,17 @@ def wrapped(*series):

return lambda *a: (wrapped(*a), arrow_return_type)

def wrap_window_agg_pandas_udf(f, return_type):
arrow_return_type = to_arrow_type(return_type)

def wrapped(*series):
import pandas as pd
import numpy as np
result = f(*series)
# Broadcast the scalar aggregate result to every row of the window
# partition. This doesn't work with non-primitive types.
return pd.Series(np.repeat(result, len(series[0])))

return lambda *a: (wrapped(*a), arrow_return_type)

def read_single_udf(pickleSer, infile, eval_type):
num_arg = read_int(infile)
@@ -151,6 +162,8 @@ def read_single_udf(pickleSer, infile, eval_type):
return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
return arg_offsets, wrap_window_agg_pandas_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
return arg_offsets, wrap_udf(func, return_type)
else:
@@ -195,7 +208,8 @@ def read_udfs(pickleSer, infile, eval_type):

if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF):
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF):
timezone = utf8_deserializer.loads(infile)
ser = ArrowStreamPandasSerializer(timezone)
else:
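For intuition on wrap_window_agg_pandas_udf above: the UDF yields one scalar per window partition, and np.repeat broadcasts that scalar across all of the partition's rows. A standalone sketch of just that step (the variable names are illustrative, not part of the patch):

import numpy as np
import pandas as pd

values = pd.Series([1.0, 2.0, 3.0])  # one partition's input column
result = values.mean()               # scalar aggregate: 2.0
# Repeat the scalar once per input row, as the wrapper does.
broadcast = pd.Series(np.repeat(result, len(values)))
print(broadcast.tolist())            # [2.0, 2.0, 2.0]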
@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
@@ -118,6 +119,8 @@ trait CheckAnalysis extends PredicateHelper {
e match {
case _: AggregateExpression | _: OffsetWindowFunction | _: AggregateWindowFunction =>
w
case f: PythonUDF if PythonUDF.isWindowPandasUDF(f) =>
w
case _ =>
failAnalysis(s"Expression '$e' not supported within a window function.")
}
@@ -154,7 +157,7 @@ trait CheckAnalysis extends PredicateHelper {

case Aggregate(groupingExprs, aggregateExprs, child) =>
def isAggregateExpression(expr: Expression) = {
expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupAggPandasUDF(expr)
expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr)
}

def checkValidAggregateExpression(expr: Expression): Unit = expr match {
@@ -34,10 +34,15 @@ object PythonUDF {
e.isInstanceOf[PythonUDF] && SCALAR_TYPES.contains(e.asInstanceOf[PythonUDF].evalType)
}

def isGroupAggPandasUDF(e: Expression): Boolean = {
def isGroupedAggPandasUDF(e: Expression): Boolean = {
e.isInstanceOf[PythonUDF] &&
e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF
}

def isWindowPandasUDF(e: Expression): Boolean = {
e.isInstanceOf[PythonUDF] &&
e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF
Review comment (Member): nit: indent style.
}
}

/**
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.planning

import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
@@ -215,7 +216,7 @@ object PhysicalAggregation {
case agg: AggregateExpression
if !equivalentAggregateExpressions.addExpr(agg) => agg
case udf: PythonUDF
if PythonUDF.isGroupAggPandasUDF(udf) &&
if PythonUDF.isGroupedAggPandasUDF(udf) &&
!equivalentAggregateExpressions.addExpr(udf) => udf
}
}
@@ -245,7 +246,7 @@ object PhysicalAggregation {
equivalentAggregateExpressions.getEquivalentExprs(ae).headOption
.getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute
// Similar to AggregateExpression
case ue: PythonUDF if PythonUDF.isGroupAggPandasUDF(ue) =>
case ue: PythonUDF if PythonUDF.isGroupedAggPandasUDF(ue) =>
equivalentAggregateExpressions.getEquivalentExprs(ue).headOption
.getOrElse(ue).asInstanceOf[PythonUDF].resultAttribute
case expression =>
@@ -268,3 +269,43 @@ object PhysicalAggregation {
case _ => None
}
}

object PhysicalWindow {

sealed trait WindowFunctionType

// Whether the window function is a Scala window function or a Python window function.
// We don't currently support mixing these two types, so we use a single enum to represent
// all window functions in the window expression.
case object Scala extends WindowFunctionType
case object Python extends WindowFunctionType

// windowFunctionType, windowExpression, partitionSpec, orderSpec, resultExpression, child
type ReturnType =
(WindowFunctionType,
Seq[WindowExpression], Seq[Expression], Seq[SortOrder], Seq[NamedExpression], LogicalPlan)

def unapply(a: Any): Option[ReturnType] = a match {
case logical.Window(windowExpressions, partitionSpec, orderSpec, child) =>

val newWindowExpressions = windowExpressions.flatMap { expr =>
expr.collect {
case we: WindowExpression => we
}
}

val windowFunctionType = newWindowExpressions.map(_.windowFunction) match {
case wfs: Seq[Expression] if wfs.forall(PythonUDF.isWindowPandasUDF) => Python
case wfs: Seq[Expression] if !wfs.exists(PythonUDF.isWindowPandasUDF) => Scala
case _ => throw new AnalysisException(
"Cannot use a mixture of window function and window Pandas UDF")
}

val resultExpressions = windowExpressions.map(_.toAttribute)

Some((windowFunctionType,
newWindowExpressions, partitionSpec, orderSpec, resultExpressions, child))

case _ => None
}
}
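To make the mixing restriction concrete: if a built-in window function and a window pandas UDF land in the same logical Window node, the PhysicalWindow extractor above throws the AnalysisException. A hedged PySpark sketch, reusing the assumed `df`, `mean_udf`, and `w` from the usage example at the top:

from pyspark.sql.functions import mean

# Both expressions share the same window spec, so they end up in one logical
# Window node; PhysicalWindow then sees a mix of Scala and Python window
# functions and fails with
# "Cannot use a mixture of window function and window Pandas UDF".
df.select(mean_udf(df['v']).over(w), mean(df['v']).over(w)).collect()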
@@ -41,6 +41,7 @@ class SparkPlanner(
DataSourceStrategy(conf) ::
SpecialLimits ::
Aggregation ::
Window ::
JoinSelection ::
InMemoryScans ::
BasicOperators :: Nil)
@@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableS
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.execution.python.WindowInPandasExec
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.MemoryPlanV2
import org.apache.spark.sql.internal.SQLConf
@@ -327,7 +328,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case PhysicalAggregation(
namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) =>

if (aggregateExpressions.exists(PythonUDF.isGroupAggPandasUDF)) {
if (aggregateExpressions.exists(PythonUDF.isGroupedAggPandasUDF)) {
throw new AnalysisException(
"Streaming aggregation doesn't support group aggregate pandas UDF")
}
@@ -428,6 +429,21 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}

object Window extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case PhysicalWindow(
PhysicalWindow.Scala, windowExprs, partitionSpec, orderSpec, resultExprs, child) =>
execution.window.WindowExec(
windowExprs, partitionSpec, orderSpec, resultExprs, planLater(child)) :: Nil

case PhysicalWindow(
PhysicalWindow.Python, windowExprs, partitionSpec, orderSpec, resultExprs, child) =>
WindowInPandasExec(
windowExprs, partitionSpec, orderSpec, resultExprs, planLater(child)) :: Nil
case _ => Nil
}
}

protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1)

object InMemoryScans extends Strategy {
@@ -548,8 +564,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.FilterExec(f.typedCondition(f.deserializer), planLater(f.child)) :: Nil
case e @ logical.Expand(_, _, child) =>
execution.ExpandExec(e.projections, e.output, planLater(child)) :: Nil
case logical.Window(windowExprs, partitionSpec, orderSpec, child) =>
execution.window.WindowExec(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
case logical.Sample(lb, ub, withReplacement, seed, child) =>
execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil
case logical.LocalRelation(output, data, _) =>
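The net effect of the strategy change: logical Window nodes whose functions are all pandas UDFs now plan to WindowInPandasExec, while built-in window functions continue to plan to WindowExec (moved out of BasicOperators into the new Window strategy). One way to observe the split from PySpark, again reusing the assumed `df`, `mean_udf`, and `w` (a sketch; exact plan text varies by Spark version):

from pyspark.sql.functions import mean

# Built-in aggregate over a window: the physical plan shows a Window operator.
df.withColumn('m', mean(df['v']).over(w)).explain()

# Grouped agg pandas UDF over a window: the plan shows WindowInPandas instead.
df.withColumn('m', mean_udf(df['v']).over(w)).explain()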
@@ -39,7 +39,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
*/
private def belongAggregate(e: Expression, agg: Aggregate): Boolean = {
e.isInstanceOf[AggregateExpression] ||
PythonUDF.isGroupAggPandasUDF(e) ||
PythonUDF.isGroupedAggPandasUDF(e) ||
agg.groupingExpressions.exists(_.semanticEquals(e))
}
