
Commit d430609

zhengruifeng authored and huangxiaopingRD committed
[SPARK-53610][PYTHON] Limit Arrow batch sizes in CoGrouped applyInPandas and applyInArrow
### What changes were proposed in this pull request?

Limit Arrow batch sizes in CoGrouped applyInPandas and applyInArrow.

### Why are the changes needed?

To mitigate JVM-side OOM.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added test.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#52700 from zhengruifeng/arrow_batch_cogroup.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent 16927c5 commit d430609

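For context, the sketch below (not part of the patch) shows the user-facing code path the change affects: a cogrouped `applyInArrow` call run under the two batch-size settings the new tests exercise, `spark.sql.execution.arrow.maxRecordsPerBatch` and `spark.sql.execution.arrow.maxBytesPerBatch`. The UDF still receives each cogroup in full; only the Arrow stream sent from the JVM to the Python worker is split into smaller batches. It assumes a Spark build that ships cogrouped `applyInArrow` and the `maxBytesPerBatch` config, as the tests in this commit do.

```python
# Minimal usage sketch (illustrative, not taken from the patch): the batch-size
# limits only affect how data is shipped to the Python worker; the cogrouped
# function still sees whole groups.
import pyarrow as pa
from pyspark.sql import SparkSession
from pyspark.sql import functions as sf

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1000")
spark.conf.set("spark.sql.execution.arrow.maxBytesPerBatch", "1048576")

left = spark.range(1000).select((sf.col("id") % 2).alias("key"), sf.col("id").alias("v"))
right = spark.range(1000).select((sf.col("id") % 4).alias("key"), (sf.col("id") * 10).alias("w"))


def merge(key, lft, rgt):
    # Each side arrives as a complete pyarrow.Table for this key.
    return pa.Table.from_pydict(
        {"key": [key[0].as_py()], "left_rows": [lft.num_rows], "right_rows": [rgt.num_rows]}
    )


(
    left.groupby("key")
    .cogroup(right.groupby("key"))
    .applyInArrow(merge, schema="key long, left_rows long, right_rows long")
    .show()
)
```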

5 files changed, +222 -83 lines changed


python/pyspark/sql/pandas/serializers.py

Lines changed: 4 additions & 4 deletions

@@ -1302,16 +1302,16 @@ def load_stream(self, stream):
             dataframes_in_group = read_int(stream)
 
             if dataframes_in_group == 2:
-                batch1 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
-                batch2 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
+                batches1 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
+                batches2 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
                 yield (
                     [
                         self.arrow_to_pandas(c, i)
-                        for i, c in enumerate(pa.Table.from_batches(batch1).itercolumns())
+                        for i, c in enumerate(pa.Table.from_batches(batches1).itercolumns())
                     ],
                     [
                         self.arrow_to_pandas(c, i)
-                        for i, c in enumerate(pa.Table.from_batches(batch2).itercolumns())
+                        for i, c in enumerate(pa.Table.from_batches(batches2).itercolumns())
                     ],
                 )

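The rename from `batch1`/`batch2` to `batches1`/`batches2` reflects that each side of a cogroup may now arrive as several Arrow record batches rather than one; `pa.Table.from_batches` reassembles them into a single table before the pandas conversion. A small standalone PyArrow sketch of that call (illustrative only, not from the patch):

```python
# Illustrative only: combine several pyarrow.RecordBatch objects into one Table,
# mirroring how the serializer reassembles a cogroup side that was sent as
# multiple batches.
import pyarrow as pa

schema = pa.schema([("key", pa.int64()), ("v", pa.int64())])
batches = [
    pa.record_batch([pa.array([0, 1]), pa.array([10, 11])], schema=schema),
    pa.record_batch([pa.array([2, 3]), pa.array([12, 13])], schema=schema),
]

table = pa.Table.from_batches(batches)  # one logical table built from both batches
assert table.num_rows == 4
```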
python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py

Lines changed: 58 additions & 5 deletions

@@ -20,7 +20,7 @@
 
 from pyspark.errors import PythonException
 from pyspark.sql import Row
-from pyspark.sql.functions import col
+from pyspark.sql import functions as sf
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
     have_pyarrow,
@@ -39,16 +39,16 @@
 class CogroupedMapInArrowTestsMixin:
     @property
     def left(self):
-        return self.spark.range(0, 10, 2, 3).withColumn("v", col("id") * 10)
+        return self.spark.range(0, 10, 2, 3).withColumn("v", sf.col("id") * 10)
 
     @property
     def right(self):
-        return self.spark.range(0, 10, 3, 3).withColumn("v", col("id") * 10)
+        return self.spark.range(0, 10, 3, 3).withColumn("v", sf.col("id") * 10)
 
     @property
     def cogrouped(self):
-        grouped_left_df = self.left.groupBy((col("id") / 4).cast("int"))
-        grouped_right_df = self.right.groupBy((col("id") / 4).cast("int"))
+        grouped_left_df = self.left.groupBy((sf.col("id") / 4).cast("int"))
+        grouped_right_df = self.right.groupBy((sf.col("id") / 4).cast("int"))
         return grouped_left_df.cogroup(grouped_right_df)
 
     @staticmethod
@@ -309,6 +309,59 @@ def arrow_func(key, left, right):
 
         self.assertEqual(df2.join(df2).count(), 1)
 
+    def test_arrow_batch_slicing(self):
+        df1 = self.spark.range(10000000).select(
+            (sf.col("id") % 2).alias("key"), sf.col("id").alias("v")
+        )
+        cols = {f"col_{i}": sf.col("v") + i for i in range(10)}
+        df1 = df1.withColumns(cols)
+
+        df2 = self.spark.range(100000).select(
+            (sf.col("id") % 4).alias("key"), sf.col("id").alias("v")
+        )
+        cols = {f"col_{i}": sf.col("v") + i for i in range(20)}
+        df2 = df2.withColumns(cols)
+
+        def summarize(key, left, right):
+            assert len(left) == 10000000 / 2 or len(left) == 0, len(left)
+            assert len(right) == 100000 / 4, len(right)
+            return pa.Table.from_pydict(
+                {
+                    "key": [key[0].as_py()],
+                    "left_rows": [left.num_rows],
+                    "left_columns": [left.num_columns],
+                    "right_rows": [right.num_rows],
+                    "right_columns": [right.num_columns],
+                }
+            )
+
+        schema = "key long, left_rows long, left_columns long, right_rows long, right_columns long"
+
+        expected = [
+            Row(key=0, left_rows=5000000, left_columns=12, right_rows=25000, right_columns=22),
+            Row(key=1, left_rows=5000000, left_columns=12, right_rows=25000, right_columns=22),
+            Row(key=2, left_rows=0, left_columns=12, right_rows=25000, right_columns=22),
+            Row(key=3, left_rows=0, left_columns=12, right_rows=25000, right_columns=22),
+        ]
+
+        for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 1048576), (1000, 1048576)]:
+            with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
+                with self.sql_conf(
+                    {
+                        "spark.sql.execution.arrow.maxRecordsPerBatch": maxRecords,
+                        "spark.sql.execution.arrow.maxBytesPerBatch": maxBytes,
+                    }
+                ):
+                    result = (
+                        df1.groupby("key")
+                        .cogroup(df2.groupby("key"))
+                        .applyInArrow(summarize, schema=schema)
+                        .sort("key")
+                        .collect()
+                    )
+
+                    self.assertEqual(expected, result)
+
 
 class CogroupedMapInArrowTests(CogroupedMapInArrowTestsMixin, ReusedSQLTestCase):
     @classmethod

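The test above fixes `maxRecordsPerBatch` and `maxBytesPerBatch` and then checks that the function still receives each cogroup in full. Conceptually, the limiting behaves like slicing each group's Arrow data into row- and byte-capped record batches before it is streamed to Python; the rough sketch below illustrates that idea in PyArrow (the real slicing happens on the JVM side, so the helper and its sizing heuristic are assumptions for illustration, not the actual implementation):

```python
# Conceptual sketch only: split a table into record batches capped by a row
# count and an approximate byte size, which is the behaviour the configs limit.
import pyarrow as pa


def slice_into_batches(table: pa.Table, max_records: int, max_bytes: int):
    # Estimate rows per batch from the table's average row size.
    bytes_per_row = max(1, table.nbytes // max(1, table.num_rows))
    rows_by_bytes = max(1, max_bytes // bytes_per_row)
    rows_per_batch = min(max_records, rows_by_bytes) if max_records > 0 else rows_by_bytes
    return table.to_batches(max_chunksize=rows_per_batch)


t = pa.table({"id": list(range(10))})
# A 4-row cap yields batches of 4, 4 and 2 rows.
print([b.num_rows for b in slice_into_batches(t, max_records=4, max_bytes=1 << 20)])
```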
python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py

Lines changed: 76 additions & 20 deletions

@@ -18,7 +18,8 @@
 import unittest
 from typing import cast
 
-from pyspark.sql.functions import array, explode, col, lit, udf, pandas_udf, sum
+from pyspark.sql import functions as sf
+from pyspark.sql.functions import pandas_udf, udf
 from pyspark.sql.types import (
     ArrayType,
     DoubleType,
@@ -55,35 +56,35 @@ class CogroupedApplyInPandasTestsMixin:
     def data1(self):
         return (
             self.spark.range(10)
-            .withColumn("ks", array([lit(i) for i in range(20, 30)]))
-            .withColumn("k", explode(col("ks")))
-            .withColumn("v", col("k") * 10)
+            .withColumn("ks", sf.array([sf.lit(i) for i in range(20, 30)]))
+            .withColumn("k", sf.explode(sf.col("ks")))
+            .withColumn("v", sf.col("k") * 10)
             .drop("ks")
         )
 
     @property
     def data2(self):
         return (
             self.spark.range(10)
-            .withColumn("ks", array([lit(i) for i in range(20, 30)]))
-            .withColumn("k", explode(col("ks")))
-            .withColumn("v2", col("k") * 100)
+            .withColumn("ks", sf.array([sf.lit(i) for i in range(20, 30)]))
+            .withColumn("k", sf.explode(sf.col("ks")))
+            .withColumn("v2", sf.col("k") * 100)
             .drop("ks")
         )
 
     def test_simple(self):
         self._test_merge(self.data1, self.data2)
 
     def test_left_group_empty(self):
-        left = self.data1.where(col("id") % 2 == 0)
+        left = self.data1.where(sf.col("id") % 2 == 0)
         self._test_merge(left, self.data2)
 
     def test_right_group_empty(self):
-        right = self.data2.where(col("id") % 2 == 0)
+        right = self.data2.where(sf.col("id") % 2 == 0)
         self._test_merge(self.data1, right)
 
     def test_different_schemas(self):
-        right = self.data2.withColumn("v3", lit("a"))
+        right = self.data2.withColumn("v3", sf.lit("a"))
         self._test_merge(
             self.data1, right, output_schema="id long, k int, v int, v2 int, v3 string"
         )
@@ -116,9 +117,9 @@ def test_complex_group_by(self):
 
         right = pd.DataFrame.from_dict({"id": [11, 12, 13], "k": [5, 6, 7], "v2": [90, 100, 110]})
 
-        left_gdf = self.spark.createDataFrame(left).groupby(col("id") % 2 == 0)
+        left_gdf = self.spark.createDataFrame(left).groupby(sf.col("id") % 2 == 0)
 
-        right_gdf = self.spark.createDataFrame(right).groupby(col("id") % 2 == 0)
+        right_gdf = self.spark.createDataFrame(right).groupby(sf.col("id") % 2 == 0)
 
         def merge_pandas(lft, rgt):
             return pd.merge(lft[["k", "v"]], rgt[["k", "v2"]], on=["k"])
@@ -354,20 +355,20 @@ def test_with_key_right(self):
         self._test_with_key(self.data1, self.data1, isLeft=False)
 
     def test_with_key_left_group_empty(self):
-        left = self.data1.where(col("id") % 2 == 0)
+        left = self.data1.where(sf.col("id") % 2 == 0)
         self._test_with_key(left, self.data1, isLeft=True)
 
     def test_with_key_right_group_empty(self):
-        right = self.data1.where(col("id") % 2 == 0)
+        right = self.data1.where(sf.col("id") % 2 == 0)
         self._test_with_key(self.data1, right, isLeft=False)
 
     def test_with_key_complex(self):
         def left_assign_key(key, lft, _):
             return lft.assign(key=key[0])
 
         result = (
-            self.data1.groupby(col("id") % 2 == 0)
-            .cogroup(self.data2.groupby(col("id") % 2 == 0))
+            self.data1.groupby(sf.col("id") % 2 == 0)
+            .cogroup(self.data2.groupby(sf.col("id") % 2 == 0))
             .applyInPandas(left_assign_key, "id long, k int, v int, key boolean")
             .sort(["id", "k"])
             .toPandas()
@@ -456,7 +457,9 @@ def test_with_window_function(self):
         left_df = df.withColumnRenamed("value", "left").repartition(parts).cache()
         # SPARK-42132: this bug requires us to alias all columns from df here
         right_df = (
-            df.select(col("id").alias("id"), col("day").alias("day"), col("value").alias("right"))
+            df.select(
+                sf.col("id").alias("id"), sf.col("day").alias("day"), sf.col("value").alias("right")
+            )
             .repartition(parts)
             .cache()
         )
@@ -465,9 +468,9 @@
         window = Window.partitionBy("day", "id")
 
         left_grouped_df = left_df.groupBy("id", "day")
-        right_grouped_df = right_df.withColumn("day_sum", sum(col("day")).over(window)).groupBy(
-            "id", "day"
-        )
+        right_grouped_df = right_df.withColumn(
+            "day_sum", sf.sum(sf.col("day")).over(window)
+        ).groupBy("id", "day")
 
         def cogroup(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
             return pd.DataFrame(
@@ -653,6 +656,59 @@ def __test_merge_error(
         with self.assertRaisesRegex(errorClass, error_message_regex):
             self.__test_merge(left, right, by, fn, output_schema)
 
+    def test_arrow_batch_slicing(self):
+        df1 = self.spark.range(10000000).select(
+            (sf.col("id") % 2).alias("key"), sf.col("id").alias("v")
+        )
+        cols = {f"col_{i}": sf.col("v") + i for i in range(10)}
+        df1 = df1.withColumns(cols)
+
+        df2 = self.spark.range(100000).select(
+            (sf.col("id") % 4).alias("key"), sf.col("id").alias("v")
+        )
+        cols = {f"col_{i}": sf.col("v") + i for i in range(20)}
+        df2 = df2.withColumns(cols)
+
+        def summarize(key, left, right):
+            assert len(left) == 10000000 / 2 or len(left) == 0, len(left)
+            assert len(right) == 100000 / 4, len(right)
+            return pd.DataFrame(
+                {
+                    "key": [key[0]],
+                    "left_rows": [len(left)],
+                    "left_columns": [len(left.columns)],
+                    "right_rows": [len(right)],
+                    "right_columns": [len(right.columns)],
+                }
+            )
+
+        schema = "key long, left_rows long, left_columns long, right_rows long, right_columns long"
+
+        expected = [
+            Row(key=0, left_rows=5000000, left_columns=12, right_rows=25000, right_columns=22),
+            Row(key=1, left_rows=5000000, left_columns=12, right_rows=25000, right_columns=22),
+            Row(key=2, left_rows=0, left_columns=12, right_rows=25000, right_columns=22),
+            Row(key=3, left_rows=0, left_columns=12, right_rows=25000, right_columns=22),
+        ]
+
+        for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 1048576), (1000, 1048576)]:
+            with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
+                with self.sql_conf(
+                    {
+                        "spark.sql.execution.arrow.maxRecordsPerBatch": maxRecords,
+                        "spark.sql.execution.arrow.maxBytesPerBatch": maxBytes,
+                    }
+                ):
+                    result = (
+                        df1.groupby("key")
+                        .cogroup(df2.groupby("key"))
+                        .applyInPandas(summarize, schema=schema)
+                        .sort("key")
+                        .collect()
+                    )
+
+                    self.assertEqual(expected, result)
+
 
 class CogroupedApplyInPandasTests(CogroupedApplyInPandasTestsMixin, ReusedSQLTestCase):
     pass
