18 changes: 11 additions & 7 deletions python/pyspark/serializers.py
@@ -249,6 +249,15 @@ def __init__(self, timezone):
super(ArrowStreamPandasSerializer, self).__init__()
self._timezone = timezone

def arrow_to_pandas(self, arrow_column):
from pyspark.sql.types import from_arrow_type, \
_check_series_convert_date, _check_series_localize_timestamps

s = arrow_column.to_pandas()
s = _check_series_convert_date(s, from_arrow_type(arrow_column.type))
s = _check_series_localize_timestamps(s, self._timezone)
return s

def dump_stream(self, iterator, stream):
"""
Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or
@@ -271,16 +280,11 @@ def load_stream(self, stream):
"""
Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
"""
from pyspark.sql.types import from_arrow_schema, _check_dataframe_convert_date, \
_check_dataframe_localize_timestamps
import pyarrow as pa
reader = pa.open_stream(stream)
schema = from_arrow_schema(reader.schema)

for batch in reader:
pdf = batch.to_pandas()
pdf = _check_dataframe_convert_date(pdf, schema)
pdf = _check_dataframe_localize_timestamps(pdf, self._timezone)
yield [c for _, c in pdf.iteritems()]
yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]

def __repr__(self):
return "ArrowStreamPandasSerializer"
25 changes: 25 additions & 0 deletions python/pyspark/sql/functions.py
@@ -2267,6 +2267,31 @@ def pandas_udf(f=None, returnType=None, functionType=None):
| 2| 1.1094003924504583|
+---+-------------------+

Alternatively, the user can define a function that takes two arguments.
In this case, the grouping key will be passed as the first argument and the data will
be passed as the second argument. The grouping key will be passed as a tuple of numpy
data types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in
as a `pandas.DataFrame` containing all columns from the original Spark DataFrame.
This is useful when the user does not want to hardcode the grouping key in the function.

>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> import pandas as pd # doctest: +SKIP
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
... ("id", "v")) # doctest: +SKIP
>>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP
... def mean_udf(key, pdf):
... # key is a tuple of one numpy.int64, which is the value
... # of 'id' for the current group
... return pd.DataFrame([key + (pdf.v.mean(),)])
>>> df.groupby('id').apply(mean_udf).show() # doctest: +SKIP
+---+---+
| id| v|
+---+---+
| 1|1.5|
| 2|6.0|
+---+---+

.. seealso:: :meth:`pyspark.sql.GroupedData.apply`

3. GROUPED_AGG
121 changes: 110 additions & 11 deletions python/pyspark/sql/tests.py
@@ -3843,7 +3843,7 @@ def foo(df):
return df
with self.assertRaisesRegexp(ValueError, 'Invalid function'):
@pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUPED_MAP)
def foo(k, v):
def foo(k, v, w):
return k


@@ -4416,20 +4416,45 @@ def test_supported_types(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col
df = self.data.withColumn("arr", array(col("id")))

foo_udf = pandas_udf(
# Different forms of grouped map pandas UDF; the results of these are the same

output_schema = StructType(
[StructField('id', LongType()),
StructField('v', IntegerType()),
StructField('arr', ArrayType(LongType())),
StructField('v1', DoubleType()),
StructField('v2', LongType())])

udf1 = pandas_udf(
lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
StructType(
[StructField('id', LongType()),
StructField('v', IntegerType()),
StructField('arr', ArrayType(LongType())),
StructField('v1', DoubleType()),
StructField('v2', LongType())]),
output_schema,
PandasUDFType.GROUPED_MAP
)

result = df.groupby('id').apply(foo_udf).sort('id').toPandas()
expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True)
self.assertPandasEqual(expected, result)
udf2 = pandas_udf(
lambda _, pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
output_schema,
PandasUDFType.GROUPED_MAP
)

udf3 = pandas_udf(
lambda key, pdf: pdf.assign(id=key[0], v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
output_schema,
PandasUDFType.GROUPED_MAP
)

result1 = df.groupby('id').apply(udf1).sort('id').toPandas()
expected1 = df.toPandas().groupby('id').apply(udf1.func).reset_index(drop=True)

result2 = df.groupby('id').apply(udf2).sort('id').toPandas()
expected2 = expected1

result3 = df.groupby('id').apply(udf3).sort('id').toPandas()
expected3 = expected1

self.assertPandasEqual(expected1, result1)
self.assertPandasEqual(expected2, result2)
self.assertPandasEqual(expected3, result3)

def test_register_grouped_map_udf(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType
@@ -4588,6 +4613,80 @@ def test_timestamp_dst(self):
result = df.groupby('time').apply(foo_udf).sort('time')
self.assertPandasEqual(df.toPandas(), result.toPandas())

def test_udf_with_key(self):
from pyspark.sql.functions import pandas_udf, col, PandasUDFType
df = self.data
pdf = df.toPandas()

def foo1(key, pdf):
import numpy as np
assert type(key) == tuple
assert type(key[0]) == np.int64

return pdf.assign(v1=key[0],
v2=pdf.v * key[0],
v3=pdf.v * pdf.id,
v4=pdf.v * pdf.id.mean())

def foo2(key, pdf):
import numpy as np
assert type(key) == tuple
assert type(key[0]) == np.int64
assert type(key[1]) == np.int32

return pdf.assign(v1=key[0],
v2=key[1],
v3=pdf.v * key[0],
v4=pdf.v + key[1])

def foo3(key, pdf):
assert type(key) == tuple
assert len(key) == 0
return pdf.assign(v1=pdf.v * pdf.id)

# v2 is int because numpy.int64 * pd.Series<int32> results in pd.Series<int32>
# v3 is long because pd.Series<int64> * pd.Series<int32> results in pd.Series<int64>
udf1 = pandas_udf(
foo1,
'id long, v int, v1 long, v2 int, v3 long, v4 double',
PandasUDFType.GROUPED_MAP)

udf2 = pandas_udf(
foo2,
'id long, v int, v1 long, v2 int, v3 int, v4 int',
PandasUDFType.GROUPED_MAP)

udf3 = pandas_udf(
foo3,
'id long, v int, v1 long',
PandasUDFType.GROUPED_MAP)

# Test groupby column
result1 = df.groupby('id').apply(udf1).sort('id', 'v').toPandas()
expected1 = pdf.groupby('id')\
.apply(lambda x: udf1.func((x.id.iloc[0],), x))\
.sort_values(['id', 'v']).reset_index(drop=True)
self.assertPandasEqual(expected1, result1)

# Test groupby expression
result2 = df.groupby(df.id % 2).apply(udf1).sort('id', 'v').toPandas()
expected2 = pdf.groupby(pdf.id % 2)\
.apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))\
.sort_values(['id', 'v']).reset_index(drop=True)
self.assertPandasEqual(expected2, result2)

# Test complex groupby
result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort('id', 'v').toPandas()
Member:
Is there any negative test case for when the number of columns specified in groupby is different from what the UDF definition (foo2) expects?

Member:
For end users, misuse of this alternative function form could be common. For example, do we issue an appropriate error in the following cases?

  • result3 = df.groupby(df.id).apply(udf2).sort('id', 'v').toPandas()
  • result3 = df.groupby(df.id, df.v % 2, df.id).apply(udf2).sort('id', 'v').toPandas()

Member:
In that case, any error will be thrown as-is from the worker.py side, which is read and redirected to the user's end via the JVM. For instance:

from pyspark.sql.functions import pandas_udf, PandasUDFType
def test_func(key, pdf):
    assert len(key) == 0
    return pdf

udf1 = pandas_udf(test_func, "id long, v1 double", PandasUDFType.GROUPED_MAP)
spark.range(10).groupby('id').apply(udf1).sort('id').show()
18/09/04 14:22:52 ERROR TaskSetManager: Task 1 in stage 0.0 failed 1 times; aborting job
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/.../spark/python/pyspark/sql/dataframe.py", line 378, in show
    print(self._jdf.showString(n, 20, vertical))
  File "/.../spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1257, in __call__
  File "/.../spark/python/pyspark/sql/utils.py", line 63, in deco
    return f(*a, **kw)
  File "/.../spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py", line 328, in get_return_value
py4j.protocol.Py4JJavaError: An error occurred while calling o68.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 1 in stage 0.0 failed 1 times, most recent failure: Lost task 1.0 in stage 0.0 (TID 1, localhost, executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/.../python/lib/pyspark.zip/pyspark/worker.py", line 353, in main
    process()
  File "/.../python/lib/pyspark.zip/pyspark/worker.py", line 348, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 242, in <lambda>
    func = lambda _, it: map(mapper, it)
  File "<string>", line 1, in <lambda>
  File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 110, in wrapped
    result = f(key, pd.concat(value_series, axis=1))
  File "/.../spark/python/pyspark/util.py", line 99, in wrapper
    return f(*args, **kwargs)
  File "<stdin>", line 2, in test_func
AssertionError

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:418)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:172)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:122)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:372)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
	at scala.collection.convert.Wrappers$IteratorWrapper.hasNext(Wrappers.scala:30)
	at org.spark_project.guava.collect.Ordering.leastOf(Ordering.java:628)
	at org.apache.spark.util.collection.Utils$.takeOrdered(Utils.scala:37)
	at org.apache.spark.rdd.RDD$$anonfun$takeOrdered$1$$anonfun$29.apply(RDD.scala:1427)
	at org.apache.spark.rdd.RDD$$anonfun$takeOrdered$1$$anonfun$29.apply(RDD.scala:1424)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:800)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:800)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:48)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:128)
	at org.apache.spark.executor.Executor$TaskRunner$$anonfun$7.apply(Executor.scala:367)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1348)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:373)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1822)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1810)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1809)
	at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1809)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
	at scala.Option.foreach(Option.scala:257)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:926)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2043)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1992)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1981)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:737)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2158)
	at org.apache.spark.rdd.RDD$$anonfun$reduce$1.apply(RDD.scala:1029)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
	at org.apache.spark.rdd.RDD.reduce(RDD.scala:1011)
	at org.apache.spark.rdd.RDD$$anonfun$takeOrdered$1.apply(RDD.scala:1433)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
	at org.apache.spark.rdd.RDD.takeOrdered(RDD.scala:1420)
	at org.apache.spark.sql.execution.TakeOrderedAndProjectExec.executeCollect(limit.scala:207)
	at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3384)
	at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2545)
	at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2545)
	at org.apache.spark.sql.Dataset$$anonfun$53.apply(Dataset.scala:3365)
	at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:78)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:125)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:73)
	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3364)
	at org.apache.spark.sql.Dataset.head(Dataset.scala:2545)
	at org.apache.spark.sql.Dataset.take(Dataset.scala:2759)
	at org.apache.spark.sql.Dataset.getRows(Dataset.scala:255)
	at org.apache.spark.sql.Dataset.showString(Dataset.scala:292)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 353, in main
    process()
  File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 348, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 242, in <lambda>
    func = lambda _, it: map(mapper, it)
  File "<string>", line 1, in <lambda>
  File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 110, in wrapped
    result = f(key, pd.concat(value_series, axis=1))
  File "/.../spark/python/pyspark/util.py", line 99, in wrapper
    return f(*args, **kwargs)
  File "<stdin>", line 2, in test_func
AssertionError

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:418)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:172)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:122)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:372)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
	at scala.collection.convert.Wrappers$IteratorWrapper.hasNext(Wrappers.scala:30)
	at org.spark_project.guava.collect.Ordering.leastOf(Ordering.java:628)
	at org.apache.spark.util.collection.Utils$.takeOrdered(Utils.scala:37)
	at org.apache.spark.rdd.RDD$$anonfun$takeOrdered$1$$anonfun$29.apply(RDD.scala:1427)
	at org.apache.spark.rdd.RDD$$anonfun$takeOrdered$1$$anonfun$29.apply(RDD.scala:1424)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:800)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:800)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:48)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:128)
	at org.apache.spark.executor.Executor$TaskRunner$$anonfun$7.apply(Executor.scala:367)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1348)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:373)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	... 1 more

expected3 = pdf.groupby([pdf.id, pdf.v % 2])\
.apply(lambda x: udf2.func((x.id.iloc[0], (x.v % 2).iloc[0],), x))\
.sort_values(['id', 'v']).reset_index(drop=True)
self.assertPandasEqual(expected3, result3)

# Test empty groupby
result4 = df.groupby().apply(udf3).sort('id', 'v').toPandas()
expected4 = udf3.func((), pdf)
self.assertPandasEqual(expected4, result4)


@unittest.skipIf(
not _have_pandas or not _have_pyarrow,
45 changes: 38 additions & 7 deletions python/pyspark/sql/types.py
@@ -1695,6 +1695,19 @@ def from_arrow_schema(arrow_schema):
for field in arrow_schema])


def _check_series_convert_date(series, data_type):
Member:
Do we need some documentation here as well?

"""
Cast the series to datetime.date if it is a date type; otherwise, return the original series.

:param series: pandas.Series
:param data_type: a Spark data type for the series
"""
if type(data_type) == DateType:
return series.dt.date
else:
return series
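A quick illustrative sketch of this helper's behavior (not part of the diff; assumes pyspark and pandas are importable):

import pandas as pd
from pyspark.sql.types import DateType, LongType, _check_series_convert_date

s = pd.Series(pd.to_datetime(["2018-01-01", "2018-01-02"]))

# For a DateType column, the datetime64 values are cast to datetime.date ...
dates = _check_series_convert_date(s, DateType())
# ... while any other Spark type leaves the series untouched.
unchanged = _check_series_convert_date(s, LongType())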


def _check_dataframe_convert_date(pdf, schema):
""" Correct date type value to use datetime.date.

@@ -1705,8 +1718,7 @@ def _check_dataframe_convert_date(pdf, schema):
:param schema: a Spark schema of the pandas.DataFrame
"""
for field in schema:
if type(field.dataType) == DateType:
pdf[field.name] = pdf[field.name].dt.date
pdf[field.name] = _check_series_convert_date(pdf[field.name], field.dataType)
return pdf


@@ -1725,6 +1737,29 @@ def _get_local_timezone():
return os.environ.get('TZ', 'dateutil/:')


def _check_series_localize_timestamps(s, timezone):
"""
Convert timezone-aware timestamps to timezone-naive in the specified timezone or local timezone.

If the input series is not a timestamp series, then the same series is returned. If the input
series is a timestamp series, then a converted series is returned.

:param s: pandas.Series
:param timezone: the timezone to convert to. If None, then use the local timezone.
:return: a pandas.Series that has been converted to tz-naive
"""
from pyspark.sql.utils import require_minimum_pandas_version
require_minimum_pandas_version()

from pandas.api.types import is_datetime64tz_dtype
Member:
Do we have tests for these in tests.py?

Contributor Author:
This function itself, `_check_series_localize_timestamps`, doesn't have a unit test, but it's called in various arrow/pandas_udf tests related to timestamps. (A direct test could look like the sketch after this function.)

tz = timezone or _get_local_timezone()
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64tz_dtype(s.dtype):
return s.dt.tz_convert(tz).dt.tz_localize(None)
else:
return s
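For illustration, a direct unit test for this helper might look like the following sketch (hypothetical, not part of this PR; assumes pandas with timezone support is importable):

import pandas as pd
from pyspark.sql.types import _check_series_localize_timestamps

# A tz-aware input comes back tz-naive, converted to the requested zone.
s = pd.Series(pd.to_datetime(["2018-01-01 00:00:00"])).dt.tz_localize("UTC")
localized = _check_series_localize_timestamps(s, "America/Los_Angeles")
assert localized.dt.tz is None

# A non-timestamp series is returned unchanged.
s2 = pd.Series([1, 2, 3])
assert _check_series_localize_timestamps(s2, "UTC") is s2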


def _check_dataframe_localize_timestamps(pdf, timezone):
"""
Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone
@@ -1736,12 +1771,8 @@ def _check_dataframe_localize_timestamps(pdf, timezone):
from pyspark.sql.utils import require_minimum_pandas_version
require_minimum_pandas_version()

from pandas.api.types import is_datetime64tz_dtype
tz = timezone or _get_local_timezone()
for column, series in pdf.iteritems():
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64tz_dtype(series.dtype):
pdf[column] = series.dt.tz_convert(tz).dt.tz_localize(None)
pdf[column] = _check_series_localize_timestamps(series, timezone)
return pdf


19 changes: 7 additions & 12 deletions python/pyspark/sql/udf.py
@@ -17,13 +17,16 @@
"""
User-defined function related classes and functions
"""
import sys
import inspect
import functools

from pyspark import SparkContext, since
from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix
from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql.types import StringType, DataType, ArrayType, StructType, MapType, \
_parse_datatype_string, to_arrow_type, to_arrow_schema
from pyspark.util import _get_argspec

__all__ = ["UDFRegistration"]

@@ -41,18 +44,10 @@ def _create_udf(f, returnType, evalType):
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF):

import inspect
import sys
from pyspark.sql.utils import require_minimum_pyarrow_version

require_minimum_pyarrow_version()

if sys.version_info[0] < 3:
# `getargspec` is deprecated since python3.0 (incompatible with function annotations).
# See SPARK-23569.
argspec = inspect.getargspec(f)
else:
argspec = inspect.getfullargspec(f)
argspec = _get_argspec(f)

if evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF and len(argspec.args) == 0 and \
argspec.varargs is None:
@@ -61,11 +56,11 @@ def _create_udf(f, returnType, evalType):
"Instead, create a 1-arg pandas_udf and ignore the arg in your function."
)

if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF and len(argspec.args) != 1:
if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \
and len(argspec.args) not in (1, 2):
raise ValueError(
"Invalid function: pandas_udfs with function type GROUPED_MAP "
"must take a single arg that is a pandas DataFrame."
)
"must take either one argument (data) or two arguments (key, data).")

# Set the name of the UserDefinedFunction object to be the name of function f
udf_obj = UserDefinedFunction(
16 changes: 16 additions & 0 deletions python/pyspark/util.py
@@ -15,6 +15,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys
import inspect
from py4j.protocol import Py4JJavaError

__all__ = []
@@ -45,6 +48,19 @@ def _exception_message(excp):
return str(excp)


def _get_argspec(f):
"""
Get argspec of a function. Supports both Python 2 and Python 3.
"""
# `getargspec` is deprecated since python3.0 (incompatible with function annotations).
# See SPARK-23569.
if sys.version_info[0] < 3:
argspec = inspect.getargspec(f)
else:
argspec = inspect.getfullargspec(f)
return argspec
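A brief usage sketch (not part of the diff) showing how the returned argspec is consumed in udf.py to validate the arity of a GROUPED_MAP pandas UDF; the function name here is hypothetical:

from pyspark.util import _get_argspec

def grouped_map_func(key, pdf):
    return pdf

argspec = _get_argspec(grouped_map_func)
# GROUPED_MAP pandas UDFs must take either one argument (data)
# or two arguments (key, data).
assert len(argspec.args) in (1, 2)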


if __name__ == "__main__":
import doctest
(failure_count, test_count) = doctest.testmod()