[SPARK-23011][SQL][PYTHON] Support alternative function form with group aggregate pandas UDF #20295
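For context: the alternative form in the title lets the function used with `groupby().apply()` take the grouping key as an extra argument, i.e. `def foo2(key, pdf)` in addition to the existing `def foo(pdf)`. A minimal sketch of the two forms, assuming a running SparkSession `spark` (column names and UDF bodies here are illustrative, not taken from the PR):

```python
from pyspark.sql.functions import pandas_udf, PandasUDFType

df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v"))

# Existing form: the function receives only the group's data.
@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
def foo(pdf):
    return pdf.assign(v=pdf.v - pdf.v.mean())

# Alternative form: the function also receives the grouping key,
# passed as a tuple with one element per groupby column.
@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
def foo2(key, pdf):
    # key == (id_value,) here, since we group by a single column
    return pdf.assign(v=pdf.v - pdf.v.mean())

df.groupby("id").apply(foo2).show()
```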
Changes from all commits
```diff
@@ -1695,6 +1695,19 @@ def from_arrow_schema(arrow_schema):
          for field in arrow_schema])
 
 
+def _check_series_convert_date(series, data_type):
+    """
+    Cast the series to datetime.date if it's a date type, otherwise returns the original series.
+
+    :param series: pandas.Series
+    :param data_type: a Spark data type for the series
+    """
+    if type(data_type) == DateType:
+        return series.dt.date
+    else:
+        return series
+
+
 def _check_dataframe_convert_date(pdf, schema):
     """ Correct date type value to use datetime.date.
 
@@ -1705,8 +1718,7 @@ def _check_dataframe_convert_date(pdf, schema):
     :param schema: a Spark schema of the pandas.DataFrame
     """
     for field in schema:
-        if type(field.dataType) == DateType:
-            pdf[field.name] = pdf[field.name].dt.date
+        pdf[field.name] = _check_series_convert_date(pdf[field.name], field.dataType)
     return pdf
```
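Pulled out of `_check_dataframe_convert_date`, the new per-series helper makes the date cast reusable for single columns such as grouping keys. A quick sketch of its behavior, assuming it can be imported from `pyspark.sql.types` where this diff appears to live (an internal helper, not public API):

```python
import pandas as pd
from pyspark.sql.types import DateType, StringType, _check_series_convert_date

s = pd.Series(pd.to_datetime(["2018-01-17", "2018-01-18"]))

# A DateType column is cast from datetime64 values to datetime.date objects.
print(_check_series_convert_date(s, DateType()).iloc[0])      # datetime.date(2018, 1, 17)

# Any other Spark type passes through unchanged.
print(_check_series_convert_date(s, StringType()).equals(s))  # True
```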
```diff
@@ -1725,6 +1737,29 @@ def _get_local_timezone():
     return os.environ.get('TZ', 'dateutil/:')
 
 
+def _check_series_localize_timestamps(s, timezone):
+    """
+    Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone.
+
+    If the input series is not a timestamp series, then the same series is returned. If the input
+    series is a timestamp series, then a converted series is returned.
+
+    :param s: pandas.Series
+    :param timezone: the timezone to convert. if None then use local timezone
+    :return pandas.Series that have been converted to tz-naive
+    """
+    from pyspark.sql.utils import require_minimum_pandas_version
+    require_minimum_pandas_version()
+
+    from pandas.api.types import is_datetime64tz_dtype
+    tz = timezone or _get_local_timezone()
+    # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
+    if is_datetime64tz_dtype(s.dtype):
+        return s.dt.tz_convert(tz).dt.tz_localize(None)
+    else:
+        return s
+
+
 def _check_dataframe_localize_timestamps(pdf, timezone):
     """
     Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone
 
@@ -1736,12 +1771,8 @@ def _check_dataframe_localize_timestamps(pdf, timezone):
     from pyspark.sql.utils import require_minimum_pandas_version
     require_minimum_pandas_version()
 
-    from pandas.api.types import is_datetime64tz_dtype
-    tz = timezone or _get_local_timezone()
     for column, series in pdf.iteritems():
         # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
-        if is_datetime64tz_dtype(series.dtype):
-            pdf[column] = series.dt.tz_convert(tz).dt.tz_localize(None)
+        pdf[column] = _check_series_localize_timestamps(series, timezone)
     return pdf
```
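The timestamp refactoring is parallel: the per-series conversion moves into `_check_series_localize_timestamps`, and the DataFrame version just loops over columns. What the helper does for a tz-aware column, mirrored in plain pandas (timezone and values illustrative):

```python
import pandas as pd

# A tz-aware series, e.g. what Arrow hands back for TimestampType data.
s = pd.Series(pd.to_datetime(["2018-01-17 00:00:00"])).dt.tz_localize("UTC")

# Convert to the target timezone, then drop the tz info to get tz-naive
# values, exactly as the helper's tz_convert(...).tz_localize(None) does.
converted = s.dt.tz_convert("America/New_York").dt.tz_localize(None)
print(converted.iloc[0])  # 2018-01-16 19:00:00, tz-naive
```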
Is there any negative test case for when the number of columns specified in the groupby differs from what the UDF definition expects (foo2)?
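A sketch of what such a negative test could look like, assuming the `(key, pdf)` form and a SparkSession `spark` (names and schema are illustrative):

```python
from pyspark.sql.functions import pandas_udf, PandasUDFType

df = spark.createDataFrame([(1, 2, 1.0), (1, 2, 2.0)], ("id", "grp", "v"))

# The UDF body unpacks the key as if two columns were grouped on...
@pandas_udf("id long, grp long, v double", PandasUDFType.GROUPED_MAP)
def foo2(key, pdf):
    (id_value, grp_value) = key  # assumes a 2-element key
    return pdf

# ...but the groupby uses only one column, so key is a 1-tuple and the
# unpacking should fail when the UDF runs in the Python worker.
df.groupby("id").apply(foo2).show()
```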
For end users, misuse of this alternative function form could be common. For example, do we issue an appropriate error in the following cases?
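Purely as an illustration of the kind of misuse meant (hypothetical names), one such case would be a function whose arity matches neither accepted form:

```python
from pyspark.sql.functions import pandas_udf, PandasUDFType

# Three parameters: matches neither f(pdf) nor f(key, pdf), so one
# would expect a clear error, either eagerly at apply() time or when
# the UDF runs in the worker.
@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
def foo3(key, pdf, extra):
    return pdf
```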
In that case, any error will be thrown as-is from the worker.py side, which is read and redirected to the user through the JVM. For instance:
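A sketch of the kind of snippet meant here, assuming the `(key, pdf)` form (names illustrative; the exact traceback text varies by Spark version):

```python
from pyspark.sql.functions import pandas_udf, PandasUDFType

df = spark.createDataFrame([(1, 1.0)], ("id", "v"))

@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
def foo2(key, pdf):
    raise ValueError("error inside the udf")

# The ValueError is raised in the Python worker (worker.py), serialized
# back through the JVM, and re-surfaced on the driver, typically as a
# Py4JJavaError wrapping org.apache.spark.api.python.PythonException
# that carries the original Python traceback.
df.groupby("id").apply(foo2).show()
```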