 import unittest
 from typing import cast
 
-from pyspark.sql.functions import array, explode, col, lit, udf, pandas_udf, sum
+from pyspark.sql import functions as sf
+from pyspark.sql.functions import pandas_udf, udf
 from pyspark.sql.types import (
     ArrayType,
     DoubleType,
@@ -55,35 +56,35 @@ class CogroupedApplyInPandasTestsMixin:
     def data1(self):
         return (
             self.spark.range(10)
-            .withColumn("ks", array([lit(i) for i in range(20, 30)]))
-            .withColumn("k", explode(col("ks")))
-            .withColumn("v", col("k") * 10)
+            .withColumn("ks", sf.array([sf.lit(i) for i in range(20, 30)]))
+            .withColumn("k", sf.explode(sf.col("ks")))
+            .withColumn("v", sf.col("k") * 10)
             .drop("ks")
         )
 
     @property
     def data2(self):
         return (
             self.spark.range(10)
-            .withColumn("ks", array([lit(i) for i in range(20, 30)]))
-            .withColumn("k", explode(col("ks")))
-            .withColumn("v2", col("k") * 100)
+            .withColumn("ks", sf.array([sf.lit(i) for i in range(20, 30)]))
+            .withColumn("k", sf.explode(sf.col("ks")))
+            .withColumn("v2", sf.col("k") * 100)
             .drop("ks")
         )
 
     def test_simple(self):
         self._test_merge(self.data1, self.data2)
 
     def test_left_group_empty(self):
-        left = self.data1.where(col("id") % 2 == 0)
+        left = self.data1.where(sf.col("id") % 2 == 0)
         self._test_merge(left, self.data2)
 
     def test_right_group_empty(self):
-        right = self.data2.where(col("id") % 2 == 0)
+        right = self.data2.where(sf.col("id") % 2 == 0)
         self._test_merge(self.data1, right)
 
     def test_different_schemas(self):
-        right = self.data2.withColumn("v3", lit("a"))
+        right = self.data2.withColumn("v3", sf.lit("a"))
         self._test_merge(
             self.data1, right, output_schema="id long, k int, v int, v2 int, v3 string"
         )
@@ -116,9 +117,9 @@ def test_complex_group_by(self):
 
         right = pd.DataFrame.from_dict({"id": [11, 12, 13], "k": [5, 6, 7], "v2": [90, 100, 110]})
 
-        left_gdf = self.spark.createDataFrame(left).groupby(col("id") % 2 == 0)
+        left_gdf = self.spark.createDataFrame(left).groupby(sf.col("id") % 2 == 0)
 
-        right_gdf = self.spark.createDataFrame(right).groupby(col("id") % 2 == 0)
+        right_gdf = self.spark.createDataFrame(right).groupby(sf.col("id") % 2 == 0)
 
         def merge_pandas(lft, rgt):
             return pd.merge(lft[["k", "v"]], rgt[["k", "v2"]], on=["k"])
@@ -354,20 +355,20 @@ def test_with_key_right(self):
         self._test_with_key(self.data1, self.data1, isLeft=False)
 
     def test_with_key_left_group_empty(self):
-        left = self.data1.where(col("id") % 2 == 0)
+        left = self.data1.where(sf.col("id") % 2 == 0)
         self._test_with_key(left, self.data1, isLeft=True)
 
     def test_with_key_right_group_empty(self):
-        right = self.data1.where(col("id") % 2 == 0)
+        right = self.data1.where(sf.col("id") % 2 == 0)
         self._test_with_key(self.data1, right, isLeft=False)
 
     def test_with_key_complex(self):
         def left_assign_key(key, lft, _):
             return lft.assign(key=key[0])
 
         result = (
-            self.data1.groupby(col("id") % 2 == 0)
-            .cogroup(self.data2.groupby(col("id") % 2 == 0))
+            self.data1.groupby(sf.col("id") % 2 == 0)
+            .cogroup(self.data2.groupby(sf.col("id") % 2 == 0))
             .applyInPandas(left_assign_key, "id long, k int, v int, key boolean")
             .sort(["id", "k"])
             .toPandas()
@@ -456,7 +457,9 @@ def test_with_window_function(self):
         left_df = df.withColumnRenamed("value", "left").repartition(parts).cache()
         # SPARK-42132: this bug requires us to alias all columns from df here
         right_df = (
-            df.select(col("id").alias("id"), col("day").alias("day"), col("value").alias("right"))
+            df.select(
+                sf.col("id").alias("id"), sf.col("day").alias("day"), sf.col("value").alias("right")
+            )
             .repartition(parts)
             .cache()
         )
@@ -465,9 +468,9 @@ def test_with_window_function(self):
         window = Window.partitionBy("day", "id")
 
         left_grouped_df = left_df.groupBy("id", "day")
-        right_grouped_df = right_df.withColumn("day_sum", sum(col("day")).over(window)).groupBy(
-            "id", "day"
-        )
+        right_grouped_df = right_df.withColumn(
+            "day_sum", sf.sum(sf.col("day")).over(window)
+        ).groupBy("id", "day")
 
         def cogroup(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
             return pd.DataFrame(
@@ -653,6 +656,59 @@ def __test_merge_error(
         with self.assertRaisesRegex(errorClass, error_message_regex):
             self.__test_merge(left, right, by, fn, output_schema)
 
+    def test_arrow_batch_slicing(self):
+        df1 = self.spark.range(10000000).select(
+            (sf.col("id") % 2).alias("key"), sf.col("id").alias("v")
+        )
+        cols = {f"col_{i}": sf.col("v") + i for i in range(10)}
+        df1 = df1.withColumns(cols)
+
+        df2 = self.spark.range(100000).select(
+            (sf.col("id") % 4).alias("key"), sf.col("id").alias("v")
+        )
+        cols = {f"col_{i}": sf.col("v") + i for i in range(20)}
+        df2 = df2.withColumns(cols)
+
+        def summarize(key, left, right):
+            assert len(left) == 10000000 / 2 or len(left) == 0, len(left)
+            assert len(right) == 100000 / 4, len(right)
+            return pd.DataFrame(
+                {
+                    "key": [key[0]],
+                    "left_rows": [len(left)],
+                    "left_columns": [len(left.columns)],
+                    "right_rows": [len(right)],
+                    "right_columns": [len(right.columns)],
+                }
+            )
+
+        schema = "key long, left_rows long, left_columns long, right_rows long, right_columns long"
+
+        expected = [
+            Row(key=0, left_rows=5000000, left_columns=12, right_rows=25000, right_columns=22),
+            Row(key=1, left_rows=5000000, left_columns=12, right_rows=25000, right_columns=22),
+            Row(key=2, left_rows=0, left_columns=12, right_rows=25000, right_columns=22),
+            Row(key=3, left_rows=0, left_columns=12, right_rows=25000, right_columns=22),
+        ]
+
+        for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 1048576), (1000, 1048576)]:
+            with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
+                with self.sql_conf(
+                    {
+                        "spark.sql.execution.arrow.maxRecordsPerBatch": maxRecords,
+                        "spark.sql.execution.arrow.maxBytesPerBatch": maxBytes,
+                    }
+                ):
+                    result = (
+                        df1.groupby("key")
+                        .cogroup(df2.groupby("key"))
+                        .applyInPandas(summarize, schema=schema)
+                        .sort("key")
+                        .collect()
+                    )
+
+                    self.assertEqual(expected, result)
+
 
 class CogroupedApplyInPandasTests(CogroupedApplyInPandasTestsMixin, ReusedSQLTestCase):
     pass