Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 4 additions & 12 deletions python/pyspark/sql/tests/test_pandas_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,23 @@

import unittest

from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
from pyspark.sql.types import *
from pyspark.sql.utils import ParseException
from pyspark.rdd import PythonEvalType
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
pandas_requirement_message, pyarrow_requirement_message
from pyspark.testing.utils import QuietTest

from py4j.protocol import Py4JJavaError


@unittest.skipIf(
not have_pandas or not have_pyarrow,
pandas_requirement_message or pyarrow_requirement_message)
class PandasUDFTests(ReusedSQLTestCase):

def test_pandas_udf_basic(self):
from pyspark.rdd import PythonEvalType
from pyspark.sql.functions import pandas_udf, PandasUDFType

udf = pandas_udf(lambda x: x, DoubleType())
self.assertEqual(udf.returnType, DoubleType())
self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
Expand Down Expand Up @@ -65,10 +66,6 @@ def test_pandas_udf_basic(self):
self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

def test_pandas_udf_decorator(self):
from pyspark.rdd import PythonEvalType
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import StructType, StructField, DoubleType

@pandas_udf(DoubleType())
def foo(x):
return x
Expand Down Expand Up @@ -114,8 +111,6 @@ def foo(x):
self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

def test_udf_wrong_arg(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType

with QuietTest(self.sc):
with self.assertRaises(ParseException):
@pandas_udf('blah')
Expand Down Expand Up @@ -151,9 +146,6 @@ def foo(k, v, w):
return k

def test_stopiteration_in_udf(self):
from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
from py4j.protocol import Py4JJavaError

def foo(x):
raise StopIteration()

Expand Down
39 changes: 3 additions & 36 deletions python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

import unittest

from pyspark.rdd import PythonEvalType
from pyspark.sql.functions import array, explode, col, lit, mean, sum, \
udf, pandas_udf, PandasUDFType
from pyspark.sql.types import *
from pyspark.sql.utils import AnalysisException
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
Expand All @@ -31,7 +34,6 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):

@property
def data(self):
from pyspark.sql.functions import array, explode, col, lit
return self.spark.range(10).toDF('id') \
.withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \
.withColumn("v", explode(col('vs'))) \
Expand All @@ -40,8 +42,6 @@ def data(self):

@property
def python_plus_one(self):
from pyspark.sql.functions import udf

@udf('double')
def plus_one(v):
assert isinstance(v, (int, float))
Expand All @@ -51,7 +51,6 @@ def plus_one(v):
@property
def pandas_scalar_plus_two(self):
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf('double', PandasUDFType.SCALAR)
def plus_two(v):
Expand All @@ -61,17 +60,13 @@ def plus_two(v):

@property
def pandas_agg_mean_udf(self):
    """Return a grouped-aggregate pandas UDF that computes ``v.mean()``.

    The UDF is declared with a ``'double'`` return type and the
    ``PandasUDFType.GROUPED_AGG`` function type.
    """
    from pyspark.sql.functions import pandas_udf, PandasUDFType

    # NOTE(review): `v` is presumably a pandas Series holding one group's
    # values -- confirm against the pandas_udf GROUPED_AGG contract.
    def avg(v):
        return v.mean()

    # Explicit call form instead of decorator syntax; the wrapped function
    # is still named `avg`.
    return pandas_udf(avg, 'double', PandasUDFType.GROUPED_AGG)

@property
def pandas_agg_sum_udf(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf('double', PandasUDFType.GROUPED_AGG)
def sum(v):
return v.sum()
Expand All @@ -80,16 +75,13 @@ def sum(v):
@property
def pandas_agg_weighted_mean_udf(self):
    """Return a grouped-aggregate pandas UDF computing a weighted mean.

    The wrapped function takes values ``v`` and weights ``w`` and returns
    ``np.average(v, weights=w)``. The UDF is declared with a ``'double'``
    return type and the ``PandasUDFType.GROUPED_AGG`` function type.
    """
    import numpy as np
    from pyspark.sql.functions import pandas_udf, PandasUDFType

    # NOTE(review): `v` and `w` are presumably equal-length pandas Series
    # for a single group -- confirm against the callers.
    def weighted_mean(v, w):
        return np.average(v, weights=w)

    # Explicit call form instead of decorator syntax; the wrapped function
    # is still named `weighted_mean`.
    return pandas_udf(weighted_mean, 'double', PandasUDFType.GROUPED_AGG)

def test_manual(self):
from pyspark.sql.functions import pandas_udf, array

df = self.data
sum_udf = self.pandas_agg_sum_udf
mean_udf = self.pandas_agg_mean_udf
Expand Down Expand Up @@ -118,8 +110,6 @@ def test_manual(self):
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())

def test_basic(self):
from pyspark.sql.functions import col, lit, mean

df = self.data
weighted_mean_udf = self.pandas_agg_weighted_mean_udf

Expand Down Expand Up @@ -150,9 +140,6 @@ def test_basic(self):
self.assertPandasEqual(expected4.toPandas(), result4.toPandas())

def test_unsupported_types(self):
from pyspark.sql.types import DoubleType, MapType
from pyspark.sql.functions import pandas_udf, PandasUDFType

with QuietTest(self.sc):
with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
pandas_udf(
Expand All @@ -173,8 +160,6 @@ def mean_and_std_udf(v):
return {v.mean(): v.std()}

def test_alias(self):
from pyspark.sql.functions import mean

df = self.data
mean_udf = self.pandas_agg_mean_udf

Expand All @@ -187,8 +172,6 @@ def test_mixed_sql(self):
"""
Test mixing group aggregate pandas UDF with sql expression.
"""
from pyspark.sql.functions import sum

df = self.data
sum_udf = self.pandas_agg_sum_udf

Expand Down Expand Up @@ -225,8 +208,6 @@ def test_mixed_udfs(self):
"""
Test mixing group aggregate pandas UDF with python UDF and scalar pandas UDF.
"""
from pyspark.sql.functions import sum

df = self.data
plus_one = self.python_plus_one
plus_two = self.pandas_scalar_plus_two
Expand Down Expand Up @@ -292,8 +273,6 @@ def test_multiple_udfs(self):
"""
Test multiple group aggregate pandas UDFs in one agg function.
"""
from pyspark.sql.functions import sum, mean

df = self.data
mean_udf = self.pandas_agg_mean_udf
sum_udf = self.pandas_agg_sum_udf
Expand All @@ -315,8 +294,6 @@ def test_multiple_udfs(self):
self.assertPandasEqual(expected1, result1)

def test_complex_groupby(self):
from pyspark.sql.functions import sum

df = self.data
sum_udf = self.pandas_agg_sum_udf
plus_one = self.python_plus_one
Expand Down Expand Up @@ -359,8 +336,6 @@ def test_complex_groupby(self):
self.assertPandasEqual(expected7.toPandas(), result7.toPandas())

def test_complex_expressions(self):
from pyspark.sql.functions import col, sum

df = self.data
plus_one = self.python_plus_one
plus_two = self.pandas_scalar_plus_two
Expand Down Expand Up @@ -434,7 +409,6 @@ def test_complex_expressions(self):
self.assertPandasEqual(expected3, result3)

def test_retain_group_columns(self):
from pyspark.sql.functions import sum
with self.sql_conf({"spark.sql.retainGroupColumns": False}):
df = self.data
sum_udf = self.pandas_agg_sum_udf
Expand All @@ -444,17 +418,13 @@ def test_retain_group_columns(self):
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())

def test_array_type(self):
    """Check that a grouped-aggregate pandas UDF can return an array.

    Builds a GROUPED_AGG UDF declared as ``'array<double>'`` that always
    returns ``[1.0, 2.0]``, aggregates column ``v`` per ``id`` under the
    alias ``v2``, and asserts that the first result row carries that list.
    """
    from pyspark.sql.functions import pandas_udf, PandasUDFType

    df = self.data

    array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG)
    result1 = df.groupby('id').agg(array_udf(df['v']).alias('v2'))
    # `assertEquals` is a deprecated alias of `assertEqual` and was removed
    # from unittest in Python 3.12; use the canonical name.
    self.assertEqual(result1.first()['v2'], [1.0, 2.0])

def test_invalid_args(self):
from pyspark.sql.functions import mean

df = self.data
plus_one = self.python_plus_one
mean_udf = self.pandas_agg_mean_udf
Expand All @@ -478,9 +448,6 @@ def test_invalid_args(self):
df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()

def test_register_vectorized_udf_basic(self):
from pyspark.sql.functions import pandas_udf
from pyspark.rdd import PythonEvalType

sum_pandas_udf = pandas_udf(
lambda v: v.sum(), "integer", PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)

Expand Down
Loading