[SPARK-48755][SS][PYTHON] transformWithState pyspark base implementation and ValueState support #47133
Changes from 48 commits
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -44,3 +44,4 @@ This page gives an overview of all public Spark SQL API. | |
| variant_val | ||
| protobuf | ||
| datasource | ||
| stateful_processor | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| .. Licensed to the Apache Software Foundation (ASF) under one | ||
| or more contributor license agreements. See the NOTICE file | ||
| distributed with this work for additional information | ||
| regarding copyright ownership. The ASF licenses this file | ||
| to you under the Apache License, Version 2.0 (the | ||
| "License"); you may not use this file except in compliance | ||
| with the License. You may obtain a copy of the License at | ||
|
|
||
| .. http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| .. Unless required by applicable law or agreed to in writing, | ||
| software distributed under the License is distributed on an | ||
| "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| KIND, either express or implied. See the License for the | ||
| specific language governing permissions and limitations | ||
| under the License. | ||
|
|
||
|
|
||
| ================== | ||
| Stateful Processor | ||
| ================== | ||
| .. currentmodule:: pyspark.sql.streaming | ||
|
|
||
| .. autosummary:: | ||
| :toctree: api/ | ||
|
|
||
| StatefulProcessor.init | ||
| StatefulProcessor.handleInputRows | ||
| StatefulProcessor.close |
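For orientation, here is a minimal sketch of a processor implementing the three methods listed above (`init`, `handleInputRows`, `close`), following the ValueState API added in this PR. The class name, state name, schema, and counting logic are illustrative assumptions, not part of this diff:

```python
# Minimal sketch of a StatefulProcessor subclass; the class name, state name,
# and counting logic are hypothetical and only show the expected shape.
from typing import Any, Iterator

import pandas as pd

from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import IntegerType, StructField, StructType


class RowCountProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        # Register a value state holding a single integer per grouping key.
        state_schema = StructType([StructField("value", IntegerType(), True)])
        self.count_state = handle.getValueState("rowCount", state_schema)

    def handleInputRows(
        self, key: Any, rows: Iterator[pd.DataFrame]
    ) -> Iterator[pd.DataFrame]:
        # Read the previous count (if any), add the rows seen in this trigger,
        # persist the new total, and emit one output row for this key.
        current = self.count_state.get().get("value")[0] if self.count_state.exists() else 0
        for pdf in rows:
            current += len(pdf)
        self.count_state.update((int(current),))
        yield pd.DataFrame({"id": key, "count": [int(current)]})

    def close(self) -> None:
        # No resources to release in this sketch.
        pass
```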
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,14 +15,19 @@ | |
| # limitations under the License. | ||
| # | ||
| import sys | ||
| from typing import List, Union, TYPE_CHECKING, cast | ||
| from typing import Any, Iterator, List, Union, TYPE_CHECKING, cast | ||
| import warnings | ||
|
|
||
| from pyspark.errors import PySparkTypeError | ||
| from pyspark.util import PythonEvalType | ||
| from pyspark.sql.column import Column | ||
| from pyspark.sql.dataframe import DataFrame | ||
| from pyspark.sql.streaming.state import GroupStateTimeout | ||
| from pyspark.sql.streaming.stateful_processor_api_client import ( | ||
| StatefulProcessorApiClient, | ||
| StatefulProcessorHandleState, | ||
| ) | ||
| from pyspark.sql.streaming.stateful_processor import StatefulProcessor, StatefulProcessorHandle | ||
| from pyspark.sql.types import StructType, _parse_datatype_string | ||
|
|
||
| if TYPE_CHECKING: | ||
|
|
@@ -33,6 +38,7 @@ | |
| PandasCogroupedMapFunction, | ||
| ArrowGroupedMapFunction, | ||
| ArrowCogroupedMapFunction, | ||
| DataFrameLike as PandasDataFrameLike, | ||
| ) | ||
| from pyspark.sql.group import GroupedData | ||
|
|
||
|
|
@@ -358,6 +364,152 @@ def applyInPandasWithState( | |
| ) | ||
| return DataFrame(jdf, self.session) | ||
|
|
||
| def transformWithStateInPandas( | ||
| self, | ||
| statefulProcessor: StatefulProcessor, | ||
| outputStructType: Union[StructType, str], | ||
| outputMode: str, | ||
| timeMode: str, | ||
| ) -> DataFrame: | ||
| """ | ||
| Invokes methods defined in the stateful processor used in arbitrary state API v2. It | ||
| requires protobuf as a dependency to transmit state messages/data. The user can act on a | ||
| per-group set of input rows along with keyed state, and can choose to output/return | ||
| 0 or more rows. | ||
|
|
||
| For a streaming DataFrame, we will repeatedly invoke the interface methods for new rows | ||
| in each trigger, and the user's state/state variables will be stored persistently across | ||
| invocations. | ||
|
|
||
| The `statefulProcessor` should be a Python class that implements the interface defined in | ||
| pyspark.sql.streaming.stateful_processor.StatefulProcessor. | ||
|
|
||
| The `outputStructType` should be a :class:`StructType` describing the schema of all | ||
| elements in the returned value, `pandas.DataFrame`. The column labels of all elements in | ||
| returned `pandas.DataFrame` must either match the field names in the defined schema if | ||
| specified as strings, or match the field data types by position if not strings, | ||
| e.g. integer indices. | ||
|
|
||
| The size of each `pandas.DataFrame` in both the input and output can be arbitrary. The | ||
| number of `pandas.DataFrame` in both the input and output can also be arbitrary. | ||
|
|
||
| .. versionadded:: 4.0.0 | ||
|
|
||
| Parameters | ||
| ---------- | ||
| statefulProcessor : :class:`pyspark.sql.streaming.stateful_processor.StatefulProcessor` | ||
| Instance of StatefulProcessor whose functions will be invoked by the operator. | ||
| outputStructType : :class:`pyspark.sql.types.DataType` or str | ||
| The type of the output records. The value can be either a | ||
| :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. | ||
| outputMode : str | ||
| The output mode of the stateful processor. | ||
| timeMode : str | ||
| The time mode semantics of the stateful processor for timers and TTL. | ||
|
|
||
| Examples | ||
| -------- | ||
| >>> from typing import Iterator | ||
| ... | ||
| >>> import pandas as pd # doctest: +SKIP | ||
| ... | ||
| >>> from pyspark.sql import Row | ||
| >>> from pyspark.sql.functions import col, split | ||
| >>> from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle | ||
| >>> from pyspark.sql.types import IntegerType, LongType, StringType, StructField, StructType | ||
| >>> spark.conf.set("spark.sql.streaming.stateStore.providerClass", | ||
| ... "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider") | ||
| ... # Below is a simple example of a stateful processor that counts the number of violations | ||
| ... # for a set of temperature sensors. A violation is defined when the temperature is above | ||
| ... # 100. | ||
| ... # The input data is a DataFrame with the following schema: | ||
| ... # `id: string, temperature: long`. | ||
| ... # The output schema and state schema are defined as below. | ||
| >>> output_schema = StructType([ | ||
| ... StructField("id", StringType(), True), | ||
| ... StructField("count", IntegerType(), True) | ||
| ... ]) | ||
| >>> state_schema = StructType([ | ||
| ... StructField("value", IntegerType(), True) | ||
| ... ]) | ||
| >>> class SimpleStatefulProcessor(StatefulProcessor): | ||
| ... def init(self, handle: StatefulProcessorHandle): | ||
| ... self.num_violations_state = handle.getValueState("numViolations", state_schema) | ||
| ... | ||
| ... def handleInputRows(self, key, rows): | ||
| ... new_violations = 0 | ||
| ... count = 0 | ||
| ... exists = self.num_violations_state.exists() | ||
| ... if exists: | ||
| ... existing_violations_pdf = self.num_violations_state.get() | ||
| ... existing_violations = existing_violations_pdf.get("value")[0] | ||
| ... else: | ||
| ... existing_violations = 0 | ||
| ... for pdf in rows: | ||
| ... pdf_count = pdf.count() | ||
| ... count += pdf_count.get('temperature') | ||
| ... violations_pdf = pdf.loc[pdf['temperature'] > 100] | ||
| ... new_violations += violations_pdf.count().get('temperature') | ||
| ... updated_violations = new_violations + existing_violations | ||
| ... self.num_violations_state.update((updated_violations,)) | ||
| ... yield pd.DataFrame({'id': key, 'count': updated_violations}) | ||
|
Contributor comment: I guess the explanation is to produce the number of violations instead of the number of inputs. This doesn't follow the explanation.
||
| ... | ||
| ... def close(self) -> None: | ||
| ... pass | ||
| ... | ||
| >>> df.groupBy("id").transformWithStateInPandas(statefulProcessor = | ||
| ... SimpleStatefulProcessor(), outputStructType=output_schema, outputMode="Update", | ||
| ... timeMode="None") # doctest: +SKIP | ||
|
|
||
| Notes | ||
| ----- | ||
| This function requires a full shuffle. | ||
|
|
||
| This API is experimental. | ||
| """ | ||
|
|
||
| from pyspark.sql import GroupedData | ||
| from pyspark.sql.functions import pandas_udf | ||
|
|
||
| assert isinstance(self, GroupedData) | ||
|
|
||
| def transformWithStateUDF( | ||
| statefulProcessorApiClient: StatefulProcessorApiClient, | ||
| key: Any, | ||
| inputRows: Iterator["PandasDataFrameLike"], | ||
| ) -> Iterator["PandasDataFrameLike"]: | ||
| handle = StatefulProcessorHandle(statefulProcessorApiClient) | ||
|
|
||
| if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED: | ||
| statefulProcessor.init(handle) | ||
| statefulProcessorApiClient.set_handle_state( | ||
| StatefulProcessorHandleState.INITIALIZED | ||
| ) | ||
|
|
||
| statefulProcessorApiClient.set_implicit_key(key) | ||
| result = statefulProcessor.handleInputRows(key, inputRows) | ||
|
|
||
| return result | ||
|
|
||
| if isinstance(outputStructType, str): | ||
| outputStructType = cast(StructType, _parse_datatype_string(outputStructType)) | ||
|
|
||
| udf = pandas_udf( | ||
| transformWithStateUDF, # type: ignore | ||
| returnType=outputStructType, | ||
| functionType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF, | ||
| ) | ||
| df = self._df | ||
| udf_column = udf(*[df[col] for col in df.columns]) | ||
|
|
||
| jdf = self._jgd.transformWithStateInPandas( | ||
| udf_column._jc.expr(), | ||
| self.session._jsparkSession.parseDataType(outputStructType.json()), | ||
| outputMode, | ||
| timeMode, | ||
| ) | ||
| return DataFrame(jdf, self.session) | ||
|
|
||
| def applyInArrow( | ||
| self, func: "ArrowGroupedMapFunction", schema: Union[StructType, str] | ||
| ) -> "DataFrame": | ||
|
|
||
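As a usage illustration (not part of the diff), the new API could be wired into a streaming query roughly as below. The socket source, the parsing of `id,temperature` lines, and the checkpoint path are assumptions made only for this sketch; `SimpleStatefulProcessor` refers to the class defined in the docstring example above:

```python
# Hypothetical end-to-end wiring for transformWithStateInPandas; the source,
# schema parsing, and checkpoint path below are illustrative assumptions.
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, split
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

spark = SparkSession.builder.getOrCreate()
spark.conf.set(
    "spark.sql.streaming.stateStore.providerClass",
    "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider",
)

output_schema = StructType([
    StructField("id", StringType(), True),
    StructField("count", IntegerType(), True),
])

# Parse "id,temperature" text lines from a socket source into the schema the
# docstring example expects (`id: string, temperature: long`).
lines = (
    spark.readStream.format("socket")
    .option("host", "localhost")
    .option("port", 9999)
    .load()
)
readings = lines.select(
    split(col("value"), ",").getItem(0).alias("id"),
    split(col("value"), ",").getItem(1).cast("long").alias("temperature"),
)

query = (
    readings.groupBy("id")
    .transformWithStateInPandas(
        statefulProcessor=SimpleStatefulProcessor(),  # from the docstring example
        outputStructType=output_schema,
        outputMode="Update",
        timeMode="None",
    )
    .writeStream.outputMode("update")
    .format("console")
    .option("checkpointLocation", "/tmp/tws-checkpoint")  # illustrative path
    .start()
)
query.awaitTermination()
```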
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,9 +19,16 @@ | |
| Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for more details. | ||
| """ | ||
|
|
||
| from itertools import groupby | ||
| from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError | ||
| from pyspark.loose_version import LooseVersion | ||
| from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer, CPickleSerializer | ||
| from pyspark.serializers import ( | ||
| Serializer, | ||
| read_int, | ||
| write_int, | ||
| UTF8Deserializer, | ||
| CPickleSerializer, | ||
| ) | ||
| from pyspark.sql.pandas.types import ( | ||
| from_arrow_type, | ||
| to_arrow_type, | ||
|
|
@@ -1116,3 +1123,69 @@ def init_stream_yield_batches(batches): | |
| batches_to_write = init_stream_yield_batches(serialize_batches()) | ||
|
|
||
| return ArrowStreamSerializer.dump_stream(self, batches_to_write, stream) | ||
|
|
||
|
|
||
| class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer): | ||
| """ | ||
| Serializer used by Python worker to evaluate UDF for transformWithStateInPandas. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| timezone : str | ||
| A timezone to respect when handling timestamp values | ||
| safecheck : bool | ||
| If True, conversion from Arrow to Pandas checks for overflow/truncation | ||
| assign_cols_by_name : bool | ||
| If True, then Pandas DataFrames will get columns by name | ||
| arrow_max_records_per_batch : int | ||
| Limit of the number of records that can be written to a single ArrowRecordBatch in memory. | ||
| """ | ||
|
|
||
| def __init__(self, timezone, safecheck, assign_cols_by_name, arrow_max_records_per_batch): | ||
| super(TransformWithStateInPandasSerializer, self).__init__( | ||
| timezone, safecheck, assign_cols_by_name | ||
| ) | ||
| self.arrow_max_records_per_batch = arrow_max_records_per_batch | ||
| self.key_offsets = None | ||
|
|
||
| def load_stream(self, stream): | ||
| """ | ||
| Read ArrowRecordBatches from stream, deserialize them to populate a list of data chunks, | ||
| and convert the data into a list of pandas.Series. | ||
|
|
||
| Please refer to the docstring of the inner function `generate_data_batches` for more | ||
| details on how this function works overall. | ||
| """ | ||
| import pyarrow as pa | ||
|
|
||
| def generate_data_batches(batches): | ||
| """ | ||
| Deserialize ArrowRecordBatches and return a generator of lists of pandas.Series. | ||
|
|
||
| The deserialization logic assumes that Arrow RecordBatches are ordered such that data | ||
| chunks for the same grouping key appear sequentially. | ||
|
|
||
| This function must avoid materializing multiple Arrow RecordBatches in memory at the | ||
| same time, and data chunks from the same grouping key must appear sequentially. | ||
| """ | ||
| for batch in batches: | ||
| data_pandas = [ | ||
| self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns() | ||
| ] | ||
| key_series = [data_pandas[o] for o in self.key_offsets] | ||
| batch_key = tuple(s[0] for s in key_series) | ||
| yield (batch_key, data_pandas) | ||
|
|
||
| _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) | ||
| data_batches = generate_data_batches(_batches) | ||
|
|
||
| for k, g in groupby(data_batches, key=lambda x: x[0]): | ||
| yield (k, g) | ||
|
|
||
| def dump_stream(self, iterator, stream): | ||
| """ | ||
| Read through an iterator of (iterator of pandas DataFrame), serialize them to Arrow | ||
| RecordBatches, and write batches to stream. | ||
| """ | ||
| result = [(b, t) for x in iterator for y, t in x for b in y] | ||
| super().dump_stream(result, stream) | ||
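One detail worth calling out about `load_stream` above: it groups the generated `(key, data)` pairs with `itertools.groupby`, which only merges contiguous runs of equal keys, so correctness depends on the ordering assumption stated in `generate_data_batches`. A small, Spark-free sketch of that behavior (the toy keys and payloads are illustrative only):

```python
# itertools.groupby merges only *contiguous* runs of equal keys, which is why
# the Arrow batches must keep chunks of the same grouping key adjacent.
from itertools import groupby

data_batches = [
    (("sensor-1",), "chunk-a"),
    (("sensor-1",), "chunk-b"),  # same key, contiguous -> same group
    (("sensor-2",), "chunk-c"),
    (("sensor-1",), "chunk-d"),  # same key, but not contiguous -> a new group
]

for key, chunks in groupby(data_batches, key=lambda x: x[0]):
    print(key, [payload for _, payload in chunks])

# Prints:
# ('sensor-1',) ['chunk-a', 'chunk-b']
# ('sensor-2',) ['chunk-c']
# ('sensor-1',) ['chunk-d']
```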