Changes from 1 commit
Commits
44 commits
512958b
Basework
HeartSaVioR Jul 12, 2022
d36373b
Add Python implementation
HyukjinKwon Jul 14, 2022
f754fd9
Reorder key attributes from deduplicated data attributes
HyukjinKwon Jul 27, 2022
5194e0c
Apply suggestions from code review
HyukjinKwon Jul 27, 2022
1301ee5
Refactoring a bit to respect the column order
HyukjinKwon Aug 11, 2022
135a826
WIP Changes to execute in pipelined manner
HeartSaVioR Aug 15, 2022
9282e5c
WIP further optimization
HeartSaVioR Aug 18, 2022
a792c98
WIP comments for more tunes
HeartSaVioR Aug 18, 2022
27e7af9
WIP further tune...
HeartSaVioR Aug 18, 2022
04a6b98
WIP done more tune! didn't do any of pandas/arrow side tunes
HeartSaVioR Aug 18, 2022
765f4d3
WIP avoid adding additional empty row for state, empty row will be ad…
HeartSaVioR Aug 19, 2022
9e11225
WIP remove debug log
HeartSaVioR Aug 19, 2022
f33d978
WIP hack around to see the possibility of perf gain on binpacking
HeartSaVioR Aug 27, 2022
8604fdf
WIP proper work to apply binpacking on python worker -> executor
HeartSaVioR Aug 27, 2022
0d024e0
WIP fix silly bug
HeartSaVioR Aug 27, 2022
43c623b
WIP another silly bugfix on migration
HeartSaVioR Aug 27, 2022
af1725a
WIP apply binpacking for executor -> python worker as well
HeartSaVioR Aug 27, 2022
31e9687
WIP fix silly bug
HeartSaVioR Aug 27, 2022
cad77a2
WIP fix another silly bug
HeartSaVioR Aug 27, 2022
c3da996
WIP batching per specified size, with sampling
HeartSaVioR Aug 29, 2022
cfb2780
WIP introduce DBR-only change
HeartSaVioR Aug 29, 2022
228b140
WIP debugging now...
HeartSaVioR Aug 29, 2022
ee4ed57
WIP still debugging... weirdness happened
HeartSaVioR Aug 30, 2022
4045ab3
WIP small fix
HeartSaVioR Aug 30, 2022
2d115ab
WIP fix a serious bug... make sure all columns in Arrow RecordBatch h…
HeartSaVioR Aug 30, 2022
3e7d785
WIP strengthen test
HeartSaVioR Aug 30, 2022
029dae7
WIP documenting the changes for pipelining and bin-packing... not yet…
HeartSaVioR Sep 2, 2022
d7ecaf9
WIP sync
HeartSaVioR Sep 2, 2022
6a6dd20
WIP start with is_last_chunk since it's easier to implement... severa…
HeartSaVioR Sep 2, 2022
5cfd59c
WIP adjust the test code to make test pass with multiple calls
HeartSaVioR Sep 2, 2022
63f8f87
WIP refactor a bit... just extract the abstract classes to explicit ones
HeartSaVioR Sep 5, 2022
6e772cd
WIP iterator of DatFrame done! updated tests and they all passed
HeartSaVioR Sep 5, 2022
00836b5
WIP FIX pyspark side test failure
HeartSaVioR Sep 6, 2022
5fdde94
WIP sort out codebase a bit
HeartSaVioR Sep 14, 2022
e7ad043
WIP no batch query support in applyInPandasWithState
HeartSaVioR Sep 6, 2022
5070b81
WIP address some missed things
HeartSaVioR Sep 6, 2022
1b919b8
WIP remove comments which are obsolete or won't be addressed
HeartSaVioR Sep 7, 2022
198fc17
WIP change the return type of user function to Iterator[DataFrame]
HeartSaVioR Sep 7, 2022
f2a75f1
WIP remove unnecessary interface/implementation changes on GroupState…
HeartSaVioR Sep 13, 2022
3e5f5d4
WIP refine out some code
HeartSaVioR Sep 13, 2022
4e34d29
WIP fix scalastyle
HeartSaVioR Sep 13, 2022
50e743e
WIP remove obsolete class
HeartSaVioR Sep 13, 2022
d22d7db
WIP remove the temp fix
HeartSaVioR Sep 13, 2022
e60408f
remove unused code
HeartSaVioR Sep 14, 2022
WIP refine out some code
HeartSaVioR committed Sep 14, 2022
commit 3e5f5d4e878e5a9a1e71490191cf98751c273739
61 changes: 33 additions & 28 deletions python/pyspark/sql/pandas/serializers.py
@@ -475,6 +475,7 @@ def gen_data_and_state(batches):

data_state_generator = gen_data_and_state(batches)

# state will be the same object for the same grouping key
for state, data in groupby(data_state_generator, key=lambda x: x[1]):
yield (data, state,)

@@ -486,6 +487,11 @@ def dump_stream(self, iterator, stream):
"""

def construct_record_batch(pdfs, pdf_data_cnt, pdf_schema, state_pdfs, state_data_cnt):
"""
Arrow RecordBatch requires all columns to have the same number of rows.
Insert empty data for whichever of data/state has fewer elements to compensate.
"""

import pandas as pd
import pyarrow as pa

@@ -525,9 +531,6 @@ def init_stream_yield_batches():
state_data_cnt = 0

sampled_data_size_per_row = 0
sampled_state_size = 0
# FIXME: sample with empty state size separately?
sampled_empty_state_size = 0

last_purged_time_ns = time.time_ns()

@@ -539,15 +542,37 @@ def init_stream_yield_batches():
# this won't change across batches
return_schema = packaged_result[1]

# FIXME: arrow type to pandas type
# FIXME: probably also need to check columns to validate?

for pdf in pdf_iter:
# FIXME: probably need to reduce down the scope of record batch to this?
if len(pdf) > 0:
pdf_data_cnt += len(pdf)
pdfs.append(pdf)

if sampled_data_size_per_row == 0 and \
pdf_data_cnt > self.minDataCountForSample:
memory_usages = [p.memory_usage(deep=True).sum() for p in pdfs]
sampled_data_size_per_row = sum(memory_usages) / pdf_data_cnt

# This effectively works after the sampling has completed, since we multiply by 0
# if the sampling is still in progress.
batch_over_limit_on_size = (sampled_data_size_per_row * pdf_data_cnt) >= \
self.softLimitBytesPerBatch

if batch_over_limit_on_size:
batch = construct_record_batch(pdfs, pdf_data_cnt, return_schema,
state_pdfs, state_data_cnt)

pdfs = []
state_pdfs = []
pdf_data_cnt = 0
state_data_cnt = 0
last_purged_time_ns = time.time_ns()

if should_write_start_length:
write_int(SpecialLengths.START_ARROW_STREAM, stream)
should_write_start_length = False

yield batch

# pick up state for only last chunk as state should have been updated so far
state_properties = state.json().encode("utf-8")
state_key_row_as_binary = state._keyAsUnsafe
@@ -566,24 +591,10 @@ def init_stream_yield_batches():
state_pdfs.append(state_pdf)
state_data_cnt += 1

# FIXME: threshold of sample data
if sampled_data_size_per_row == 0 and pdf_data_cnt > self.minDataCountForSample:
memory_usages = [p.memory_usage(deep=True).sum() for p in pdfs]
sampled_data_size_per_row = sum(memory_usages) / pdf_data_cnt

# FIXME: threshold of sample data
if sampled_state_size == 0 and state_data_cnt > self.minDataCountForSample:
memory_usages = [p.memory_usage(deep=True).sum() for p in state_pdfs]
sampled_state_size = sum(memory_usages) / state_data_cnt

# This effectively works after the sampling has completed, since we multiply by 0
# if the sampling is still in progress.
batch_over_limit_on_size = (sampled_data_size_per_row * pdf_data_cnt) + \
(sampled_state_size * state_data_cnt) >= self.softLimitBytesPerBatch
cur_time_ns = time.time_ns()
is_timed_out_on_purge = ((cur_time_ns - last_purged_time_ns) // 1000000) >= \
self.softTimeoutMillisPurgeBatch
if batch_over_limit_on_size or is_timed_out_on_purge:
if is_timed_out_on_purge:
batch = construct_record_batch(pdfs, pdf_data_cnt, return_schema,
state_pdfs, state_data_cnt)

@@ -604,14 +615,8 @@ def init_stream_yield_batches():
batch = construct_record_batch(pdfs, pdf_data_cnt, return_schema,
state_pdfs, state_data_cnt)

pdfs = []
state_pdfs = []
pdf_data_cnt = 0
state_data_cnt = 0

if should_write_start_length:
write_int(SpecialLengths.START_ARROW_STREAM, stream)
should_write_start_length = False

yield batch
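
For reference, the batching strategy implemented above is: buffer the pandas DataFrames produced by the user function, estimate a per-row size once enough rows have been sampled, and flush an Arrow RecordBatch when the estimated accumulated size crosses the soft limit or the soft timeout elapses. A minimal standalone sketch of that decision, using illustrative names rather than the serializer's actual attributes:

import time

def should_flush(pdfs, pdf_data_cnt, sampled_size_per_row, last_purged_time_ns,
                 soft_limit_bytes, soft_timeout_ms, min_count_for_sample):
    # pdfs is a list of buffered pandas DataFrames; pdf_data_cnt is their total row count.
    # Returns (flush_now, updated per-row size estimate).
    if sampled_size_per_row == 0 and pdf_data_cnt > min_count_for_sample:
        # Estimate the per-row size once enough rows have been buffered.
        total_bytes = sum(p.memory_usage(deep=True).sum() for p in pdfs)
        sampled_size_per_row = total_bytes / pdf_data_cnt

    # The size check is a no-op (0 * count) until the sampling has completed.
    over_size_limit = (sampled_size_per_row * pdf_data_cnt) >= soft_limit_bytes
    timed_out = ((time.time_ns() - last_purged_time_ns) // 1_000_000) >= soft_timeout_ms
    return over_size_limit or timed_out, sampled_size_per_row

In the diff above, the size-based check runs while draining the user function's output, and the timeout-based check runs after the state metadata row has been appended.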

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2707,21 +2707,38 @@ object SQLConf {

val MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH =
buildConf("spark.sql.execution.applyInPandasWithState.softLimitSizePerBatch")
// FIXME: doc
.internal()
.doc("When using applyInPandasWithState, set a soft limit of the accumulated size of " +
"records that can be written to a single ArrowRecordBatch in memory. This is used to " +
"restrict the amount of memory being used to materialize the data in both executor and " +
"Python worker. The accumulated size of records are calculated via sampling a set of " +
"records. Splitting the ArrowRecordBatch is performed per record, so unless a record " +
"is quite huge, the size of constructed ArrowRecordBatch will be around the " +
"configured value.")
.version("3.4.0")
.bytesConf(ByteUnit.BYTE)
.createWithDefaultString("64MB")

val MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE =
buildConf("spark.sql.execution.applyInPandasWithState.minDataCountForSample")
// FIXME: doc
.internal()
.doc("When using applyInPandasWithState, specify the minimum number of records to sample " +
"the size of record. The size being retrieved from sampling will be used to estimate " +
"the accumulated size of records. Note that limiting by size does not work if the " +
"number of records are less than the configured value. For such case, ArrowRecordBatch " +
"will only be split for soft timeout.")
.version("3.4.0")
.intConf
.createWithDefault(100)

val MAP_PANDAS_UDF_WITH_STATE_SOFT_TIMEOUT_PURGE_BATCH =
buildConf("spark.sql.execution.applyInPandasWithState.softTimeoutPurgeBatch")
// FIXME: doc
.internal()
.doc("When using applyInPandasWithState, specify the soft timeout for purging the " +
"ArrowRecordBatch. If batching records exceeds the timeout, Spark will force splitting " +
"the ArrowRecordBatch regardless of estimated size. This config ensures the receiver " +
"of data (both executor and Python worker) to not wait indefinitely for sender to " +
"complete the ArrowRecordBatch, which may hurt both throughput and latency.")
.version("3.4.0")
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("100ms")
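
Taken together, these three internal configs control how the serializer bin-packs data into Arrow RecordBatches: a soft size cap per batch, a minimum sample size before the cap can be enforced, and a soft timeout that forces a flush. A hypothetical tuning sketch; they are internal configs introduced by this PR (names and defaults may still change), and spark is assumed to be an existing SparkSession:

# Internal configs added in this PR; shown only to illustrate the knobs.
spark.conf.set("spark.sql.execution.applyInPandasWithState.softLimitSizePerBatch", "64MB")
spark.conf.set("spark.sql.execution.applyInPandasWithState.minDataCountForSample", "100")
spark.conf.set("spark.sql.execution.applyInPandasWithState.softTimeoutPurgeBatch", "100ms")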
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
@@ -0,0 +1,197 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.python

import java.io._

import scala.collection.JavaConverters._

import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.json4s._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.api.python._
import org.apache.spark.sql.Row
import org.apache.spark.sql.api.python.PythonSQLUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{InType, OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER}
import org.apache.spark.sql.execution.python.ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA
import org.apache.spark.sql.execution.streaming.GroupStateImpl
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}


/**
* [[ArrowPythonRunner]] with [[org.apache.spark.sql.streaming.GroupState]].
*/
class ApplyInPandasWithStatePythonRunner(
funcs: Seq[ChainedPythonFunctions],
evalType: Int,
argOffsets: Array[Array[Int]],
inputSchema: StructType,
override protected val timeZoneId: String,
initialWorkerConf: Map[String, String],
stateEncoder: ExpressionEncoder[Row],
keySchema: StructType,
valueSchema: StructType,
stateValueSchema: StructType,
softLimitBytesPerBatch: Long,
minDataCountForSample: Int,
softTimeoutMillsPurgeBatch: Long)
extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets)
with PythonArrowInput[InType]
with PythonArrowOutput[OutType] {

override protected val schema: StructType = inputSchema.add("!__state__!", STATE_METADATA_SCHEMA)

override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback

override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize
require(
bufferSize >= 4,
"Pandas execution requires more than 4 bytes. Please set higher buffer. " +
s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.")

override protected val workerConf: Map[String, String] = initialWorkerConf +
(SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH.key ->
softLimitBytesPerBatch.toString) +
(SQLConf.MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE.key ->
minDataCountForSample.toString) +
(SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_TIMEOUT_PURGE_BATCH.key ->
softTimeoutMillsPurgeBatch.toString)

private val stateRowDeserializer = stateEncoder.createDeserializer()

override protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
super.handleMetadataBeforeExec(stream)
// Also write the schema for state value
PythonRDD.writeUTF(stateValueSchema.json, stream)
}

protected def writeIteratorToArrowStream(
root: VectorSchemaRoot,
writer: ArrowStreamWriter,
dataOut: DataOutputStream,
inputIterator: Iterator[InType]): Unit = {
val w = new ApplyInPandasWithStateWriter(root, writer, softLimitBytesPerBatch,
minDataCountForSample, softTimeoutMillsPurgeBatch)

while (inputIterator.hasNext) {
val (keyRow, groupState, dataIter) = inputIterator.next()
assert(dataIter.hasNext, "should have at least one data row!")
w.startNewGroup(keyRow, groupState)

while (dataIter.hasNext) {
val dataRow = dataIter.next()
w.writeRow(dataRow)
}

w.finalizeGroup()
}

w.finalizeData()
}

protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OutType = {
// This should at least have one row for state. Also, we ensure that all columns across
// data and state metadata have the same number of rows, which is required by Arrow record
// batch.
assert(batch.numRows() > 0)
assert(schema.length == 2)

def getColumnarBatchForStructTypeColumn(
batch: ColumnarBatch,
ordinal: Int,
expectedType: StructType): ColumnarBatch = {
// UDF returns a StructType column in ColumnarBatch, select the children here
val structVector = batch.column(ordinal).asInstanceOf[ArrowColumnVector]
val dataType = schema(ordinal).dataType.asInstanceOf[StructType]
assert(dataType.sameType(expectedType))

val outputVectors = dataType.indices.map(structVector.getChild)
val flattenedBatch = new ColumnarBatch(outputVectors.toArray)
flattenedBatch.setNumRows(batch.numRows())

flattenedBatch
}

def constructIterForData(batch: ColumnarBatch): Iterator[InternalRow] = {
val dataBatch = getColumnarBatchForStructTypeColumn(batch, 0, valueSchema)
dataBatch.rowIterator.asScala.flatMap { row =>
if (row.isNullAt(0)) {
// The entire row in record batch seems to be for state metadata.
None
} else {
Some(row)
}
}
}

def constructIterForState(batch: ColumnarBatch): Iterator[OutTypeForState] = {
val stateMetadataBatch = getColumnarBatchForStructTypeColumn(batch, 1,
STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER)

stateMetadataBatch.rowIterator().asScala.flatMap { row =>
implicit val formats = org.json4s.DefaultFormats

if (row.isNullAt(0)) {
// The entire row in record batch seems to be for data.
None
} else {
// NOTE: See StateReaderIterator.STATE_METADATA_SCHEMA for the schema.
val propertiesAsJson = parse(row.getUTF8String(0).toString)
val keyRowAsUnsafeAsBinary = row.getBinary(1)
val keyRowAsUnsafe = new UnsafeRow(keySchema.fields.length)
keyRowAsUnsafe.pointTo(keyRowAsUnsafeAsBinary, keyRowAsUnsafeAsBinary.length)
val maybeObjectRow = if (row.isNullAt(2)) {
None
} else {
val pickledStateValue = row.getBinary(2)
Some(PythonSQLUtils.toJVMRow(pickledStateValue, stateValueSchema,
stateRowDeserializer))
}
val oldTimeoutTimestamp = row.getLong(3)

Some((keyRowAsUnsafe, GroupStateImpl.fromJson(maybeObjectRow, propertiesAsJson),
oldTimeoutTimestamp))
}
}
}

(constructIterForState(batch), constructIterForData(batch))
}
}

object ApplyInPandasWithStatePythonRunner {
type InType = (UnsafeRow, GroupStateImpl[Row], Iterator[InternalRow])
type OutTypeForState = (UnsafeRow, GroupStateImpl[Row], Long)
type OutType = (Iterator[OutTypeForState], Iterator[InternalRow])

val STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER: StructType = StructType(
Array(
StructField("properties", StringType),
StructField("keyRowAsUnsafe", BinaryType),
StructField("object", BinaryType),
StructField("oldTimeoutTimestamp", LongType),
)
)
}
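
For context, a rough usage sketch of the Python API this runner backs, following the Iterator[pandas.DataFrame] function shape the commits settle on. The module path for GroupState/GroupStateTimeout, the exact parameter order of applyInPandasWithState, and the streaming DataFrame df with its "key" column are assumptions for illustration, since the API is still in flux in this WIP:

import pandas as pd
from typing import Any, Iterator, Tuple
from pyspark.sql.streaming.state import GroupState, GroupStateTimeout

output_schema = "key string, count long"  # schema of the yielded pandas DataFrames
state_schema = "count long"               # schema of the user-defined state tuple

def count_fn(
        key: Tuple[Any, ...],
        pdfs: Iterator[pd.DataFrame],
        state: GroupState) -> Iterator[pd.DataFrame]:
    # Accumulate a running count per grouping key across micro-batches.
    count = state.get[0] if state.exists else 0
    for pdf in pdfs:
        count += len(pdf)
    state.update((count,))
    yield pd.DataFrame({"key": [key[0]], "count": [count]})

result = (df.groupBy("key")
            .applyInPandasWithState(count_fn, output_schema, state_schema,
                                    "update", GroupStateTimeout.NoTimeout))

Each ArrowRecordBatch exchanged for such a query carries the data columns plus the extra state metadata column handled by this runner.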