Clean up; Add more tests
icexelloss committed Jun 12, 2018
commit abdfd9eb75e89c42a51eb9f11da2283766567b4c
37 changes: 30 additions & 7 deletions python/pyspark/sql/tests.py
@@ -5454,6 +5454,15 @@ def test_retain_group_columns(self):
expected1 = df.groupby(df.id).agg(sum(df.v))
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())

def test_array_type(self):
@icexelloss (Contributor, Author) commented on Apr 21, 2018:

This is unrelated, but I figured it shouldn't hurt to add an array test in GroupedAggPandasUDFTests.

from pyspark.sql.functions import pandas_udf, PandasUDFType

df = self.data

array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG)
result1 = df.groupby('id').agg(array_udf(df['v']).alias('v2'))
self.assertEquals(result1.first()['v2'], [1.0, 2.0])

def test_invalid_args(self):
from pyspark.sql.functions import mean

@@ -5556,9 +5565,6 @@ def test_multiple_udfs(self):
.withColumn('max_v', max(df['v']).over(w)) \
.withColumn('min_w', min(df['w']).over(w))

result1.explain(True)
expected1.explain(True)

self.assertPandasEqual(expected1.toPandas(), result1.toPandas())

def test_replace_existing(self):
@@ -5646,29 +5652,46 @@ def test_mixed_sql_and_udf(self):
expected3 = expected1

# Test mixing sql window function and udf
result4 = df.withColumn('max_v', max_udf(df['v']).over(w)).withColumn('rank', rank().over(ow))
expected4 = df.withColumn('max_v', max(df['v']).over(w)).withColumn('rank', rank().over(ow))
result4 = df.withColumn('max_v', max_udf(df['v']).over(w)) \
.withColumn('rank', rank().over(ow))
expected4 = df.withColumn('max_v', max(df['v']).over(w)) \
.withColumn('rank', rank().over(ow))

self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
self.assertPandasEqual(expected4.toPandas(), result4.toPandas())

def test_array_type(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType

df = self.data
w = self.unbounded_window

array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG)
result1 = df.withColumn('v2', array_udf(df['v']).over(w))
self.assertEquals(result1.first()['v2'], [1.0, 2.0])

def test_invalid_args(self):
from pyspark.sql.functions import mean, pandas_udf, PandasUDFType

df = self.data
w = self.unbounded_window

ow = self.ordered_window
mean_udf = self.pandas_agg_mean_udf

with QuietTest(self.sc):
with self.assertRaisesRegexp(
AnalysisException,
'.*not supported within a window function'):
'.*does not have any WindowFunction'):
foo_udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP)
df.withColumn('v2', foo_udf(df['v']).over(w))

with QuietTest(self.sc):
with self.assertRaisesRegexp(
AnalysisException,
'Only unbounded window frame is supported with Python UDFs.'):
df.withColumn('mean_v', mean_udf(df['v']).over(ow))


if __name__ == "__main__":
4 changes: 1 addition & 3 deletions python/pyspark/worker.py
@@ -133,10 +133,8 @@ def wrap_window_agg_pandas_udf(f, return_type):

def wrapped(*series):
import pandas as pd
import numpy as np
result = f(*series)
# This doesn't work with non primitive types
return pd.Series(np.repeat(result, len(series[0])))
return pd.Series([result]).repeat(len(series[0]))
A member commented:

Just wondering why this needs to be repeated to the length of the series when grouped agg doesn't?

@icexelloss (Contributor, Author) replied:

Window aggregation results are broadcast to each input row, and therefore we repeat the value here to match the input rows.
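
To make this concrete, here is a minimal sketch in plain pandas (outside Spark; values are illustrative) of what the wrapper does, and why pd.Series([result]).repeat(...) replaces np.repeat for non-primitive results:

    import numpy as np
    import pandas as pd

    v = pd.Series([1.0, 2.0, 3.0])      # stand-in for one window's input column
    result = [1.0, 2.0]                 # a non-primitive (array-type) aggregate result

    np.repeat(result, len(v))           # array([1., 1., 1., 2., 2., 2.]) - flattens the value
    pd.Series([result]).repeat(len(v))  # 3 rows, each holding [1.0, 2.0] - one per input row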

@HyukjinKwon (Member) commented on May 31, 2018:

Let's leave a short comment while we are here.

@icexelloss (Contributor, Author) replied:

Added comments to describe the function.

@HyukjinKwon (Member) commented on May 31, 2018:

So this is the only place where the behavior diverges (by repeating), and that's why we need a window-specific attribute to distinguish grouped agg from window agg?

@icexelloss (Contributor, Author) replied:

Yes - I tried to do this on the Java side, but merging the input rows with the UDF output is tricky and complicated when they are not a 1-1 mapping, so I ended up doing it here.
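
At the DataFrame level, the resulting semantics look like the following hedged sketch (assuming a DataFrame df with columns 'id' and 'v', as in the tests above): the output stays 1-1 with the input, with the aggregate value broadcast onto every row of its window.

    from pyspark.sql import Window
    from pyspark.sql.functions import pandas_udf, PandasUDFType

    @pandas_udf('double', PandasUDFType.GROUPED_AGG)
    def mean_udf(v):
        return v.mean()

    w = Window.partitionBy('id') \
        .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

    # One output row per input row; 'mean_v' holds the per-partition mean,
    # repeated ("broadcast") onto each row by the wrapper above.
    df.withColumn('mean_v', mean_udf(df['v']).over(w)).show()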


return lambda *a: (wrapped(*a), arrow_return_type)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -113,9 +113,14 @@ trait CheckAnalysis extends PredicateHelper {
failAnalysis("An offset window function can only be evaluated in an ordered " +
s"row-based window frame with a single offset: $w")

case w @ WindowExpression(_: PythonUDF,
WindowSpecDefinition(_, _, frame: SpecifiedWindowFrame))
A member commented:

indentation :-)

if !frame.isUnbounded =>
failAnalysis(s"Only unbounded window frame is supported with Python UDFs.")

case w @ WindowExpression(e, s) =>
// Only allow window functions with an aggregate expression or an offset window
// function.
// function or a Pandas window UDF.
e match {
case _: AggregateExpression | _: OffsetWindowFunction | _: AggregateWindowFunction =>
w
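The check added above rejects bounded frames for Pandas UDFs at analysis time; only unbounded frames pass. A minimal sketch of both sides of the check, reusing the hypothetical df and mean_udf from the earlier sketch:

    from pyspark.sql import Window

    unbounded = Window.partitionBy('id') \
        .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
    bounded = Window.partitionBy('id').orderBy('v').rowsBetween(-1, 1)

    df.withColumn('mean_v', mean_udf(df['v']).over(unbounded))  # passes analysis
    df.withColumn('mean_v', mean_udf(df['v']).over(bounded))
    # AnalysisException: Only unbounded window frame is supported with Python UDFs.
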
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -309,8 +309,7 @@ object WindowFunctionType {

def functionType(windowExpression: NamedExpression): Option[WindowFunctionType] = {
windowExpression.collectFirst {
case _: WindowFunction => SQL
case _: AggregateFunction => SQL
case _: WindowFunction | _: AggregateFunction => SQL
case udf: PythonUDF if PythonUDF.isWindowPandasUDF(udf) => Python
}
}
sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
@@ -122,21 +122,17 @@ case class WindowInPandasExec(
}

val inputProj = UnsafeProjection.create(allInputs, child.output)
val pythonInput = grouped.map{ case (k, rows) =>
val pythonInput = grouped.map { case (k, rows) =>
rows.map { row =>
queue.add(row.asInstanceOf[UnsafeRow])
inputProj(row)
}
}

val pythonEvalType = udfExpressions.head.evalType match {
case PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF =>
PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF
}

val windowFunctionResult = new ArrowPythonRunner(
pyFuncs, bufferSize, reuseWorker,
pythonEvalType, argOffsets, windowInputSchema,
PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
argOffsets, windowInputSchema,
sessionLocalTimeZone, pandasRespectSessionTimeZone)
.compute(pythonInput, context.partitionId(), context)
