Commits (44 total); the diff below shows the changes from a single commit.
512958b
Basework
HeartSaVioR Jul 12, 2022
d36373b
Add Python implementation
HyukjinKwon Jul 14, 2022
f754fd9
Reorder key attributes from deduplicated data attributes
HyukjinKwon Jul 27, 2022
5194e0c
Apply suggestions from code review
HyukjinKwon Jul 27, 2022
1301ee5
Refactoring a bit to respect the column order
HyukjinKwon Aug 11, 2022
135a826
WIP Changes to execute in pipelined manner
HeartSaVioR Aug 15, 2022
9282e5c
WIP further optimization
HeartSaVioR Aug 18, 2022
a792c98
WIP comments for more tunes
HeartSaVioR Aug 18, 2022
27e7af9
WIP further tune...
HeartSaVioR Aug 18, 2022
04a6b98
WIP done more tune! didn't do any of pandas/arrow side tunes
HeartSaVioR Aug 18, 2022
765f4d3
WIP avoid adding additional empty row for state, empty row will be ad…
HeartSaVioR Aug 19, 2022
9e11225
WIP remove debug log
HeartSaVioR Aug 19, 2022
f33d978
WIP hack around to see the possibility of perf gain on binpacking
HeartSaVioR Aug 27, 2022
8604fdf
WIP proper work to apply binpacking on python worker -> executor
HeartSaVioR Aug 27, 2022
0d024e0
WIP fix silly bug
HeartSaVioR Aug 27, 2022
43c623b
WIP another silly bugfix on migration
HeartSaVioR Aug 27, 2022
af1725a
WIP apply binpacking for executor -> python worker as well
HeartSaVioR Aug 27, 2022
31e9687
WIP fix silly bug
HeartSaVioR Aug 27, 2022
cad77a2
WIP fix another silly bug
HeartSaVioR Aug 27, 2022
c3da996
WIP batching per specified size, with sampling
HeartSaVioR Aug 29, 2022
cfb2780
WIP introduce DBR-only change
HeartSaVioR Aug 29, 2022
228b140
WIP debugging now...
HeartSaVioR Aug 29, 2022
ee4ed57
WIP still debugging... weirdness happened
HeartSaVioR Aug 30, 2022
4045ab3
WIP small fix
HeartSaVioR Aug 30, 2022
2d115ab
WIP fix a serious bug... make sure all columns in Arrow RecordBatch h…
HeartSaVioR Aug 30, 2022
3e7d785
WIP strengthen test
HeartSaVioR Aug 30, 2022
029dae7
WIP documenting the changes for pipelining and bin-packing... not yet…
HeartSaVioR Sep 2, 2022
d7ecaf9
WIP sync
HeartSaVioR Sep 2, 2022
6a6dd20
WIP start with is_last_chunk since it's easier to implement... severa…
HeartSaVioR Sep 2, 2022
5cfd59c
WIP adjust the test code to make test pass with multiple calls
HeartSaVioR Sep 2, 2022
63f8f87
WIP refactor a bit... just extract the abstract classes to explicit ones
HeartSaVioR Sep 5, 2022
6e772cd
WIP iterator of DatFrame done! updated tests and they all passed
HeartSaVioR Sep 5, 2022
00836b5
WIP FIX pyspark side test failure
HeartSaVioR Sep 6, 2022
5fdde94
WIP sort out codebase a bit
HeartSaVioR Sep 14, 2022
e7ad043
WIP no batch query support in applyInPandasWithState
HeartSaVioR Sep 6, 2022
5070b81
WIP address some missed things
HeartSaVioR Sep 6, 2022
1b919b8
WIP remove comments which are obsolete or won't be addressed
HeartSaVioR Sep 7, 2022
198fc17
WIP change the return type of user function to Iterator[DataFrame]
HeartSaVioR Sep 7, 2022
f2a75f1
WIP remove unnecessary interface/implementation changes on GroupState…
HeartSaVioR Sep 13, 2022
3e5f5d4
WIP refine out some code
HeartSaVioR Sep 13, 2022
4e34d29
WIP fix scalastyle
HeartSaVioR Sep 13, 2022
50e743e
WIP remove obsolete class
HeartSaVioR Sep 13, 2022
d22d7db
WIP remove the temp fix
HeartSaVioR Sep 13, 2022
e60408f
remove unused code
HeartSaVioR Sep 14, 2022
WIP avoid adding additional empty row for state, empty row will be added only when there is no data
HeartSaVioR committed Sep 14, 2022
commit 765f4d3519da7c0fe4224742a00a03742af54611
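The idea of this commit, sketched below in plain pandas (a minimal illustration, not the serializer itself; package_result and its parameters are hypothetical names): the result state now travels in its own one-row DataFrame carrying an isEmptyData flag, and an all-null placeholder data row is built only when the user function returned no rows, instead of always prepending an empty row to the data.

import pandas as pd

def package_result(pdf, state_properties, key_row_bytes, state_object, data_column_names):
    # pdf: DataFrame returned by the user function for one group.
    empty_data = len(pdf) == 0
    if empty_data:
        # No data rows: build a single all-null row so the Arrow batch still
        # has one row to carry the state information.
        pdf = pd.DataFrame(dict.fromkeys(data_column_names), index=[0])

    # One-row state DataFrame; isEmptyData tells the reader to discard the
    # placeholder data row.
    state_pdf = pd.DataFrame.from_dict({
        "properties": [state_properties],
        "keyRowAsUnsafe": [key_row_bytes],
        "object": [state_object],
        "isEmptyData": [empty_data],
    })
    return pdf, state_pdf

Because the placeholder row exists only in the empty case, the JVM side can take the batch row count from the data columns and simply drop the data when isEmptyData is true.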
82 changes: 62 additions & 20 deletions python/pyspark/sql/pandas/serializers.py
@@ -22,7 +22,8 @@

from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer, CPickleSerializer
from pyspark.sql.pandas.types import to_arrow_type
from pyspark.sql.types import StringType, StructType, BinaryType, StructField
from pyspark.sql.types import StringType, StructType, BinaryType, StructField, BooleanType


class SpecialLengths:
END_OF_DATA_SECTION = -1
@@ -248,6 +249,7 @@ def create_array(s, t):

arrs = []
for s, t in series:
print("==== <_create_batch> s: %s t: %s" % (s, t, ), file=sys.stderr)
if t is not None and pa.types.is_struct(t):
if not isinstance(s, pd.DataFrame):
raise ValueError(
@@ -390,22 +392,52 @@ def __init__(self, timezone, safecheck, assign_cols_by_name, state_object_schema
StructField('properties', StringType()),
StructField('keyRowAsUnsafe', BinaryType()),
StructField('object', BinaryType()),
StructField('isEmptyData', BooleanType()),
])

self.result_state_pdf_arrow_type = to_arrow_type(self.result_state_df_type)

def arrow_to_pandas(self, arrow_column):
return super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column)

def load_stream(self, stream):
import pyarrow as pa
import json
from pyspark.sql.streaming.state import GroupStateImpl

batches = ArrowStreamPandasUDFSerializer.load_stream(self, stream)
batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)

for batch in batches:
# FIXME: can we leverage schema here? doesn't work well so...
state_info_col = batch[-1][0]
print("=== <load_stream> batch: %s type(batch): %s" % (batch, type(batch), ), file=sys.stderr)

batch_schema = batch.schema

print("=== <load_stream> batch_schema: %s type(batch_schema): %s" % (batch_schema, type(batch_schema), ), file=sys.stderr)

batch_columns = batch.columns
data_columns = batch_columns[0:-1]
state_column = batch_columns[-1]

print("=== <load_stream> data_columns: %s state_column: %s" % (data_columns, state_column, ), file=sys.stderr)

data_schema = pa.schema([batch_schema[i] for i in range(0, len(batch_schema) - 1)])
state_schema = pa.schema([batch_schema[-1], ])

print("=== <load_stream> data_schema: %s state_schema: %s" % (data_schema, state_schema, ), file=sys.stderr)

data_batch = pa.RecordBatch.from_arrays(data_columns, schema=data_schema)
state_batch = pa.RecordBatch.from_arrays([state_column, ], schema=state_schema)

print("=== <load_stream> data_batch: %s state_batch: %s" % (data_batch, state_batch, ), file=sys.stderr)

data_arrow = pa.Table.from_batches([data_batch]).itercolumns()
state_arrow = pa.Table.from_batches([state_batch]).itercolumns()

print("=== <load_stream> data_arrow_columns: %s state_arrow_columns: %s" % (data_arrow, state_arrow, ), file=sys.stderr)

data_pandas = [self.arrow_to_pandas(c) for c in data_arrow]
state_pandas = [self.arrow_to_pandas(c) for c in state_arrow][0]

print("=== <load_stream> data_pandas: %s type(data_pandas): %s state_pandas: %s type(state_pandas): %s" % (data_pandas, type(data_pandas), state_pandas, type(state_pandas), ), file=sys.stderr)

state_info_col = state_pandas.iloc[0]

state_info_col_properties = state_info_col['properties']
state_info_col_key_row = state_info_col['keyRowAsUnsafe']
@@ -421,10 +453,10 @@ def load_stream(self, stream):
state = GroupStateImpl(keyAsUnsafe=state_info_col_key_row,
valueSchema=self.state_object_schema, **state_properties)

state_column_dropped_series = batch[0:-1]
first_row_dropped_series = [x.iloc[1:].reset_index(drop=True) for x in state_column_dropped_series]
print("=== <load_stream> data_pandas: %s state: %s" % (data_pandas, state, ), file=sys.stderr)

# state info
yield (first_row_dropped_series, state, )
yield (data_pandas, state, )

def dump_stream(self, iterator, stream):
"""
@@ -435,36 +467,46 @@ def dump_stream(self, iterator, stream):

def init_stream_yield_batches():
import pandas as pd
import pyarrow as pa

should_write_start_length = True
for data in iterator:
packaged_result = data[0]

pdf = packaged_result[0][0].reset_index(drop=True)
pdf = packaged_result[0][0]
state = packaged_result[0][-1]
return_schema = packaged_result[1]

new_empty_row = pd.DataFrame(dict.fromkeys(pdf.columns), index=[0])
# FIXME: arrow type to pandas type
# FIXME: probably also need to check columns to validate?

print("==== <init_stream_yield_batches> pdf: %s len(pdf): %s" % (pdf, len(pdf), ), file=sys.stderr)

# Concatenate new_row with df
pdf_with_empty_row = pd.concat([new_empty_row, pdf[:]], axis=0).reset_index(drop=True)
empty_data = len(pdf) == 0
if empty_data:
# if returned DataFrame is empty with no column information, just create a new
# DataFrame with empty row with column information
pdf = pd.DataFrame(dict.fromkeys(pa.schema(return_schema).names), index=[0])

print("==== <init_stream_yield_batches> pdf: %s state: %s return_schema: %s" % (pdf, state, return_schema, ), file=sys.stderr)

state_properties = state.json().encode("utf-8")
state_key_row_as_binary = state._keyAsUnsafe
state_object = self.pickleSer.dumps(state._value_schema.toInternal(state._value))

len_pdf = len(pdf)
none_array = [None, ] * len_pdf
state_dict = {
'properties': [state_properties, ] + none_array,
'keyRowAsUnsafe': [state_key_row_as_binary, ] + none_array,
'object': [state_object, ] + none_array,
'properties': [state_properties, ],
'keyRowAsUnsafe': [state_key_row_as_binary, ],
'object': [state_object, ],
'isEmptyData': [empty_data, ],
}

state_pdf = pd.DataFrame.from_dict(state_dict)

print("==== <init_stream_yield_batches> pdf: %s return_schema: %s state_pdf: %s result_state arrow_schema: %s" % (pdf, return_schema, state_pdf, self.result_state_pdf_arrow_type, ), file=sys.stderr)

batch = self._create_batch([
(pdf_with_empty_row, return_schema),
(pdf, return_schema),
(state_pdf, self.result_state_pdf_arrow_type)])

if should_write_start_length:
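On the read path above, load_stream now splits each incoming Arrow batch column-wise instead of peeling a reserved state row off the data. Below is a rough pyarrow sketch of that split, assuming (as in the diff) that the state struct is the last column of the batch; split_data_and_state is a hypothetical helper, not the serializer method.

import pyarrow as pa

def split_data_and_state(batch):
    # batch: pa.RecordBatch whose trailing column carries the state struct.
    schema = batch.schema
    data_columns = batch.columns[:-1]
    state_column = batch.columns[-1]

    data_schema = pa.schema([schema[i] for i in range(len(schema) - 1)])
    state_schema = pa.schema([schema[len(schema) - 1]])

    data_batch = pa.RecordBatch.from_arrays(data_columns, schema=data_schema)
    state_batch = pa.RecordBatch.from_arrays([state_column], schema=state_schema)

    # Only the first row of the state column is populated.
    state_info = state_column.to_pylist()[0]
    return data_batch, state_info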
6 changes: 5 additions & 1 deletion python/pyspark/worker.py
@@ -213,7 +213,11 @@ def wrap_grouped_map_pandas_udf_with_state(f, return_type):
def wrapped(key_series, value_series, state):
import pandas as pd

key = tuple(s.head(1).at[0] for s in key_series)
print("=== <wrapped> key_series: %s value_series: %s state: %s" % (key_series, value_series, state, ), file=sys.stderr)

key = tuple(s[0] for s in key_series)
print("=== <wrapped> key: %s" % (key, ), file=sys.stderr)

if state.hasTimedOut:
# Timeout processing pass empty iterator. Here we return an empty DataFrame instead.
result = f(key, pd.DataFrame(columns=pd.concat(value_series, axis=1).columns), state)
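A simplified sketch of the wrapper logic in the hunk above (call_with_state is a hypothetical standalone function; the real wrapper is generated inside wrap_grouped_map_pandas_udf_with_state, and its non-timeout branch is not shown in this hunk, so that part is an assumption):

import pandas as pd

def call_with_state(f, key_series, value_series, state):
    # The grouping key is the first element of each key Series.
    key = tuple(s[0] for s in key_series)

    if state.hasTimedOut:
        # Timeout: there is no input data, so hand the user function an empty
        # DataFrame that still carries the value column names.
        values = pd.DataFrame(columns=pd.concat(value_series, axis=1).columns)
    else:
        # Assumed non-timeout path: concatenate the value Series into a DataFrame.
        values = pd.concat(value_series, axis=1)

    return f(key, values, state)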
39 changes: 39 additions & 0 deletions sql/core/pom.xml
@@ -269,6 +269,45 @@
</plugin>

<!-- FIXME: temporarily stop checking checkstyle -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-checkstyle-plugin</artifactId>
<version>3.1.2</version>
<configuration>
<skip>true</skip>
<failOnViolation>false</failOnViolation>
<includeTestSourceDirectory>true</includeTestSourceDirectory>
<sourceDirectories>
<directory>${basedir}/src/main/java</directory>
<directory>${basedir}/src/main/scala</directory>
</sourceDirectories>
<testSourceDirectories>
<directory>${basedir}/src/test/java</directory>
</testSourceDirectories>
<configLocation>dev/checkstyle.xml</configLocation>
<outputFile>${basedir}/target/checkstyle-output.xml</outputFile>
<inputEncoding>${project.build.sourceEncoding}</inputEncoding>
<outputEncoding>${project.reporting.outputEncoding}</outputEncoding>
</configuration>
<dependencies>
<dependency>
<!--
If you are changing the dependency setting for checkstyle plugin,
please check project/plugins.sbt too.
-->
<groupId>com.puppycrawl.tools</groupId>
<artifactId>checkstyle</artifactId>
<version>8.43</version>
</dependency>
</dependencies>
<executions>
<execution>
<goals>
<goal>check</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.scalastyle</groupId>
<artifactId>scalastyle-maven-plugin</artifactId>
@@ -44,7 +44,7 @@ object ArrowWriter {
new ArrowWriter(root, children.toArray)
}

private def createFieldWriter(vector: ValueVector): ArrowFieldWriter = {
private[sql] def createFieldWriter(vector: ValueVector): ArrowFieldWriter = {
val field = vector.getField()
(ArrowUtils.fromArrowField(field), vector) match {
case (BooleanType, vector: BitVector) => new BooleanWriter(vector)
@@ -27,15 +27,16 @@ import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter}
import org.json4s._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.{SparkEnv, TaskContext}

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.python._
import org.apache.spark.sql.Row
import org.apache.spark.sql.api.python.PythonSQLUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.execution.arrow.ArrowWriter
import org.apache.spark.sql.execution.arrow.ArrowWriter.createFieldWriter
import org.apache.spark.sql.execution.streaming.GroupStateImpl
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -81,6 +82,8 @@ class ArrowPythonRunnerWithState(
)
)

logWarning(s"DEBUG: schemaWithState: ${schemaWithState}")

val stateRowSerializer = stateEncoder.createSerializer()
val stateRowDeserializer = stateEncoder.createDeserializer()

@@ -122,37 +125,55 @@

protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
val arrowSchema = ArrowUtils.toArrowSchema(schemaWithState, timeZoneId)

logWarning(s"DEBUG: arrowSchema: ${arrowSchema}")

val allocator = ArrowUtils.rootAllocator.newChildAllocator(
s"stdout writer for $pythonExec", 0, Long.MaxValue)
val root = VectorSchemaRoot.create(arrowSchema, allocator)

Utils.tryWithSafeFinally {
val nullDataRow = new GenericInternalRow(Array.fill(inputSchema.length)(null: Any))
val nullStateInfoRow = new GenericInternalRow(Array.fill(1)(null: Any))
val arrowWriterForData = {
val children = root.getFieldVectors().asScala.dropRight(1).map { vector =>
vector.allocateNew()
createFieldWriter(vector)
}

new ArrowWriter(root, children.toArray)
}
val arrowWriterForState = {
val children = root.getFieldVectors().asScala.takeRight(1).map { vector =>
vector.allocateNew()
createFieldWriter(vector)
}
new ArrowWriter(root, children.toArray)
}

val arrowWriter = ArrowWriter.create(root)
val writer = new ArrowStreamWriter(root, null, dataOut)
writer.start()

val joinedRow = new JoinedRow
while (inputIterator.hasNext) {
val (keyRow, groupState, dataIter) = inputIterator.next()

assert(dataIter.hasNext, "should have at least one data row!")

// Provide state info row in the first row
val stateInfoRow = buildStateInfoRow(keyRow, groupState)
joinedRow.withLeft(nullDataRow).withRight(stateInfoRow)
arrowWriter.write(joinedRow)
arrowWriterForState.write(stateInfoRow)

// Continue providing remaining data rows
while (dataIter.hasNext) {
val dataRow = dataIter.next()
joinedRow.withLeft(dataRow).withRight(nullStateInfoRow)
arrowWriter.write(joinedRow)
arrowWriterForData.write(dataRow)
}

arrowWriter.finish()
// DO NOT CHANGE THE ORDER OF FINISH! We are picking up the number of rows from data
// side, as we know there is at least one data row.
arrowWriterForState.finish()
arrowWriterForData.finish()
writer.writeBatch()
arrowWriter.reset()
arrowWriterForState.reset()
arrowWriterForData.reset()
}
// end writes footer to the output stream and doesn't clean any resources.
// It could throw exception if the output stream is closed, so it should be
@@ -267,26 +288,15 @@

val rowForStateInfo = flattenedBatchForState.getRow(0)

// UDF returns a StructType column in ColumnarBatch, select the children here
val structVector = batch.column(0).asInstanceOf[ArrowColumnVector]
val outputVectors = schema(0).dataType.asInstanceOf[StructType]
.indices.map(structVector.getChild)
val flattenedBatch = new ColumnarBatch(outputVectors.toArray)
flattenedBatch.setNumRows(batch.numRows())

val rowIterator = flattenedBatch.rowIterator.asScala
// drop first row as it's reserved for state
assert(rowIterator.hasNext)
rowIterator.next()

// FIXME: we rely on known schema for state info, but would we want to access this by
// column name?
// Received state information does not need schemas - this class already knows them.
/*
Array(
StructField("properties", StringType),
StructField("keyRowAsUnsafe", BinaryType),
StructField("object", BinaryType)
StructField("object", BinaryType),
StructField('isEmptyData', BooleanType)
)
*/
implicit val formats = org.json4s.DefaultFormats
@@ -301,10 +311,26 @@
val pickledRow = rowForStateInfo.getBinary(2)
Some(PythonSQLUtils.toJVMRow(pickledRow, stateSchema, stateRowDeserializer))
}
val isEmptyData = rowForStateInfo.getBoolean(3)

val newGroupState = GroupStateImpl.fromJson(maybeObjectRow, propertiesAsJson)

(keyRowAsUnsafe, newGroupState, rowIterator.map(unsafeProjForData))
val rowIterator = if (isEmptyData) {
logWarning("DEBUG: no data is available")
Iterator.empty
} else {
// UDF returns a StructType column in ColumnarBatch, select the children here
val structVector = batch.column(0).asInstanceOf[ArrowColumnVector]
val outputVectors = schema(0).dataType.asInstanceOf[StructType]
.indices.map(structVector.getChild)
val flattenedBatch = new ColumnarBatch(outputVectors.toArray)
flattenedBatch.setNumRows(batch.numRows())

val rowIterator = flattenedBatch.rowIterator.asScala
rowIterator.map(unsafeProjForData)
}

(keyRowAsUnsafe, newGroupState, rowIterator)
}
}
}
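For the executor-to-Python direction rewritten above, the data columns and the single state struct column now live in one Arrow batch whose row count is taken from the data side; the state struct is written only into row 0 and the remaining slots stay null, which is why the finish order of the two writers must not change. Below is an illustrative pyarrow sketch of that batch layout; the helper, the state field names, and the "__state__" column name are assumptions, not the JVM code.

import pyarrow as pa

def build_combined_batch(data_arrays, data_fields, state_row):
    # Assumes at least one data row per group, as the executor side asserts.
    n = len(data_arrays[0])
    state_type = pa.struct([
        ("properties", pa.string()),
        ("keyRowAsUnsafe", pa.binary()),
        ("object", pa.binary()),
    ])
    # The state struct occupies only row 0; the other rows are null so that
    # every column in the RecordBatch has the same length.
    state_array = pa.array([state_row] + [None] * (n - 1), type=state_type)

    schema = pa.schema(list(data_fields) + [pa.field("__state__", state_type)])
    return pa.RecordBatch.from_arrays(list(data_arrays) + [state_array], schema=schema)

batch = build_combined_batch(
    data_arrays=[pa.array([1, 2, 3]), pa.array(["a", "b", "c"])],
    data_fields=[pa.field("id", pa.int64()), pa.field("value", pa.string())],
    state_row={"properties": "{}", "keyRowAsUnsafe": b"\x00", "object": b"\x00"},
)
print(batch.num_rows, batch.schema.names)  # 3 rows, the data columns plus "__state__"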