[SPARK-48755][SS][PYTHON] transformWithState pyspark base implementation and ValueState support #47133
Changes from all commits
The SQL API reference index adds the new page to its listing:

```diff
@@ -44,3 +44,4 @@ This page gives an overview of all public Spark SQL API.
     variant_val
     protobuf
     datasource
+    stateful_processor
```
A new reference page documents the `StatefulProcessor` lifecycle methods:

```diff
@@ -0,0 +1,29 @@
+..  Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+..    http://www.apache.org/licenses/LICENSE-2.0
+
+..  Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+
+==================
+Stateful Processor
+==================
+.. currentmodule:: pyspark.sql.streaming
+
+.. autosummary::
+    :toctree: api/
+
+    StatefulProcessor.init
+    StatefulProcessor.handleInputRows
+    StatefulProcessor.close
```
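For orientation, the three lifecycle methods listed above map onto a processor skeleton like the one below. This is a minimal sketch, assuming the `pyspark.sql.streaming` exports this PR introduces; the state-variable name and schema are illustrative, not taken from the diff.

```python
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import IntegerType, StructField, StructType

# Illustrative single-column state schema; any StructType works here.
state_schema = StructType([StructField("value", IntegerType(), True)])

class SkeletonProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        # Runs before any input is processed; acquire state variables here.
        self.state = handle.getValueState("myState", state_schema)

    def handleInputRows(self, key, rows):
        # Runs per grouping key with an iterator of pandas DataFrames; may
        # read/update state and yield zero or more DataFrames whose columns
        # must line up with the declared output schema.
        for pdf in rows:
            yield pdf

    def close(self) -> None:
        # Runs once when processing ends; release any resources.
        pass
```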
The pandas group-ops module picks up the imports the new API needs:

```diff
@@ -15,14 +15,19 @@
 # limitations under the License.
 #
 import sys
-from typing import List, Union, TYPE_CHECKING, cast
+from typing import Any, Iterator, List, Union, TYPE_CHECKING, cast
 import warnings

 from pyspark.errors import PySparkTypeError
 from pyspark.util import PythonEvalType
 from pyspark.sql.column import Column
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.streaming.state import GroupStateTimeout
+from pyspark.sql.streaming.stateful_processor_api_client import (
+    StatefulProcessorApiClient,
+    StatefulProcessorHandleState,
+)
+from pyspark.sql.streaming.stateful_processor import StatefulProcessor, StatefulProcessorHandle
 from pyspark.sql.types import StructType, _parse_datatype_string

 if TYPE_CHECKING:
@@ -33,6 +38,7 @@
     PandasCogroupedMapFunction,
     ArrowGroupedMapFunction,
     ArrowCogroupedMapFunction,
+    DataFrameLike as PandasDataFrameLike,
 )
 from pyspark.sql.group import GroupedData

```
The new `transformWithStateInPandas` method and its docstring:

```diff
@@ -358,6 +364,172 @@ def applyInPandasWithState(
         )
         return DataFrame(jdf, self.session)

+    def transformWithStateInPandas(
+        self,
+        statefulProcessor: StatefulProcessor,
+        outputStructType: Union[StructType, str],
+        outputMode: str,
+        timeMode: str,
+    ) -> DataFrame:
+        """
+        Invokes methods defined in the stateful processor used in arbitrary state API v2. It
+        requires protobuf, pandas and pyarrow as dependencies to process input/state data. We
+        allow the user to act on a per-group set of input rows along with keyed state, and the
+        user can choose to output/return 0 or more rows.
+
+        For a streaming dataframe, we will repeatedly invoke the interface methods for new rows
+        in each trigger, and the user's state/state variables will be stored persistently across
+        invocations.
+
+        The `statefulProcessor` should be a Python class that implements the interface defined
+        in :class:`StatefulProcessor`.
+
+        The `outputStructType` should be a :class:`StructType` describing the schema of all
+        elements in the returned value, `pandas.DataFrame`. The column labels of all elements
+        in the returned `pandas.DataFrame` must either match the field names in the defined
+        schema if specified as strings, or match the field data types by position if not
+        strings, e.g. integer indices.
+
+        The size of each `pandas.DataFrame` in both the input and output can be arbitrary. The
+        number of `pandas.DataFrame` in both the input and output can also be arbitrary.
+
+        .. versionadded:: 4.0.0
+
+        Parameters
+        ----------
+        statefulProcessor : :class:`pyspark.sql.streaming.stateful_processor.StatefulProcessor`
+            Instance of StatefulProcessor whose functions will be invoked by the operator.
+        outputStructType : :class:`pyspark.sql.types.DataType` or str
+            The type of the output records. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
+        outputMode : str
+            The output mode of the stateful processor.
+        timeMode : str
+            The time mode semantics of the stateful processor for timers and TTL.
+
+        Examples
+        --------
+        >>> from typing import Iterator
+        ...
+        >>> import pandas as pd # doctest: +SKIP
+        ...
+        >>> from pyspark.sql import Row
+        >>> from pyspark.sql.functions import col, split
+        >>> from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
+        >>> from pyspark.sql.types import IntegerType, LongType, StringType, StructField, StructType
+        ...
+        >>> spark.conf.set("spark.sql.streaming.stateStore.providerClass",
+        ...     "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")
+        ... # Below is a simple example to find erroneous sensors from temperature sensor data.
+        ... # The processor returns a count of total readings, while keeping erroneous reading
+        ... # counts in streaming state. A violation is defined when the temperature is above 100.
+        ... # The input data is a DataFrame with the following schema:
+        ... #     `id: string, temperature: long`.
+        ... # The output schema and state schema are defined as below.
+        >>> output_schema = StructType([
+        ...     StructField("id", StringType(), True),
+        ...     StructField("count", IntegerType(), True)
+        ... ])
+        >>> state_schema = StructType([
+        ...     StructField("value", IntegerType(), True)
+        ... ])
+        >>> class SimpleStatefulProcessor(StatefulProcessor):
+        ...     def init(self, handle: StatefulProcessorHandle):
+        ...         self.num_violations_state = handle.getValueState("numViolations", state_schema)
+        ...
+        ...     def handleInputRows(self, key, rows):
+        ...         new_violations = 0
+        ...         count = 0
+        ...         exists = self.num_violations_state.exists()
+        ...         if exists:
+        ...             existing_violations_row = self.num_violations_state.get()
+        ...             existing_violations = existing_violations_row[0]
+        ...         else:
+        ...             existing_violations = 0
+        ...         for pdf in rows:
+        ...             pdf_count = pdf.count()
+        ...             count += pdf_count.get('temperature')
+        ...             violations_pdf = pdf.loc[pdf['temperature'] > 100]
+        ...             new_violations += violations_pdf.count().get('temperature')
+        ...         updated_violations = new_violations + existing_violations
+        ...         self.num_violations_state.update((updated_violations,))
+        ...         yield pd.DataFrame({'id': key, 'count': count})
+        ...
+        ...     def close(self) -> None:
+        ...         pass
+
+        Input DataFrame:
+        +---+-----------+
+        | id|temperature|
+        +---+-----------+
+        |  0|        123|
+        |  0|         23|
+        |  1|         33|
+        |  1|        188|
+        |  1|         88|
+        +---+-----------+
+
+        >>> df.groupBy("value").transformWithStateInPandas(statefulProcessor =
+        ...     SimpleStatefulProcessor(), outputStructType=output_schema, outputMode="Update",
+        ...     timeMode="None") # doctest: +SKIP
+
+        Output DataFrame:
+        +---+-----+
+        | id|count|
+        +---+-----+
+        |  0|    2|
+        |  1|    3|
+        +---+-----+
+
+        Notes
+        -----
+        This function requires a full shuffle.
+
+        This API is experimental.
+        """
```

Two review comments on the docstring example were left open in the thread:

- On `yield pd.DataFrame({'id': key, 'count': count})`, a contributor noted: "I guess the explanation is to produce the number of violations instead of the number of inputs. This doesn't follow the explanation."
- On the output row `| 0| 2|`, a contributor asked: "Isn't the desired output (0, 1), (1, 1)?"
The implementation wraps the processor in a pandas UDF with the new eval type (same hunk, continued):

```diff
+
+        from pyspark.sql import GroupedData
+        from pyspark.sql.functions import pandas_udf
+
+        assert isinstance(self, GroupedData)
+
+        def transformWithStateUDF(
+            statefulProcessorApiClient: StatefulProcessorApiClient,
+            key: Any,
+            inputRows: Iterator["PandasDataFrameLike"],
+        ) -> Iterator["PandasDataFrameLike"]:
+            handle = StatefulProcessorHandle(statefulProcessorApiClient)
+
+            if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
+                statefulProcessor.init(handle)
+                statefulProcessorApiClient.set_handle_state(
+                    StatefulProcessorHandleState.INITIALIZED
+                )
+
+            statefulProcessorApiClient.set_implicit_key(key)
+            result = statefulProcessor.handleInputRows(key, inputRows)
+
+            return result
+
+        if isinstance(outputStructType, str):
+            outputStructType = cast(StructType, _parse_datatype_string(outputStructType))
+
+        udf = pandas_udf(
+            transformWithStateUDF,  # type: ignore
+            returnType=outputStructType,
+            functionType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
+        )
+        df = self._df
+        udf_column = udf(*[df[col] for col in df.columns])
+
+        jdf = self._jgd.transformWithStateInPandas(
+            udf_column._jc.expr(),
+            self.session._jsparkSession.parseDataType(outputStructType.json()),
+            outputMode,
+            timeMode,
+        )
+        return DataFrame(jdf, self.session)
+
     def applyInArrow(
         self, func: "ArrowGroupedMapFunction", schema: Union[StructType, str]
     ) -> "DataFrame":
```
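For readers tracking the ValueState support this PR adds: as exercised in the docstring example, a state variable acquired via `handle.getValueState` exposes `exists()`, `get()`, and `update()`. A minimal per-key counting sketch of that surface, under the assumptions noted in the comments (illustration only, not code from this diff):

```python
import pandas as pd

# Sketch only: assumes `self.state` was acquired in init() via
# handle.getValueState(...), with a single-int state schema as in the
# docstring example above.
def handleInputRows(self, key, rows):
    # get() returns a row-like tuple shaped like the state schema;
    # exists() guards the first invocation for a key.
    total = self.state.get()[0] if self.state.exists() else 0
    for pdf in rows:
        total += len(pdf)  # count the input rows seen this batch
    # update() persists a tuple shaped like the state schema.
    self.state.update((total,))
    # `key` arrives list-like, matching how the docstring example passes
    # it straight into pd.DataFrame({'id': key, ...}).
    yield pd.DataFrame({"id": key, "count": [total]})
```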