@@ -3903,7 +3903,7 @@ def foo(df):
39033903 return df
39043904 with self .assertRaisesRegexp (ValueError , 'Invalid function' ):
39053905 @pandas_udf (returnType = 'k int, v double' , functionType = PandasUDFType .GROUPED_MAP )
3906- def foo (k , v ):
3906+ def foo (k , v , w ):
39073907 return k
39083908
39093909
@@ -4476,20 +4476,45 @@ def test_supported_types(self):
44764476 from pyspark .sql .functions import pandas_udf , PandasUDFType , array , col
44774477 df = self .data .withColumn ("arr" , array (col ("id" )))
44784478
4479- foo_udf = pandas_udf (
4479+ # Different forms of group map pandas UDF, results of these are the same
4480+
4481+ output_schema = StructType (
4482+ [StructField ('id' , LongType ()),
4483+ StructField ('v' , IntegerType ()),
4484+ StructField ('arr' , ArrayType (LongType ())),
4485+ StructField ('v1' , DoubleType ()),
4486+ StructField ('v2' , LongType ())])
4487+
4488+ udf1 = pandas_udf (
44804489 lambda pdf : pdf .assign (v1 = pdf .v * pdf .id * 1.0 , v2 = pdf .v + pdf .id ),
4481- StructType (
4482- [StructField ('id' , LongType ()),
4483- StructField ('v' , IntegerType ()),
4484- StructField ('arr' , ArrayType (LongType ())),
4485- StructField ('v1' , DoubleType ()),
4486- StructField ('v2' , LongType ())]),
4490+ output_schema ,
44874491 PandasUDFType .GROUPED_MAP
44884492 )
44894493
4490- result = df .groupby ('id' ).apply (foo_udf ).sort ('id' ).toPandas ()
4491- expected = df .toPandas ().groupby ('id' ).apply (foo_udf .func ).reset_index (drop = True )
4492- self .assertPandasEqual (expected , result )
4494+ udf2 = pandas_udf (
4495+ lambda _ , pdf : pdf .assign (v1 = pdf .v * pdf .id * 1.0 , v2 = pdf .v + pdf .id ),
4496+ output_schema ,
4497+ PandasUDFType .GROUPED_MAP
4498+ )
4499+
4500+ udf3 = pandas_udf (
4501+ lambda key , pdf : pdf .assign (id = key [0 ], v1 = pdf .v * pdf .id * 1.0 , v2 = pdf .v + pdf .id ),
4502+ output_schema ,
4503+ PandasUDFType .GROUPED_MAP
4504+ )
4505+
4506+ result1 = df .groupby ('id' ).apply (udf1 ).sort ('id' ).toPandas ()
4507+ expected1 = df .toPandas ().groupby ('id' ).apply (udf1 .func ).reset_index (drop = True )
4508+
4509+ result2 = df .groupby ('id' ).apply (udf2 ).sort ('id' ).toPandas ()
4510+ expected2 = expected1
4511+
4512+ result3 = df .groupby ('id' ).apply (udf3 ).sort ('id' ).toPandas ()
4513+ expected3 = expected1
4514+
4515+ self .assertPandasEqual (expected1 , result1 )
4516+ self .assertPandasEqual (expected2 , result2 )
4517+ self .assertPandasEqual (expected3 , result3 )
44934518
44944519 def test_register_grouped_map_udf (self ):
44954520 from pyspark .sql .functions import pandas_udf , PandasUDFType
@@ -4648,6 +4673,80 @@ def test_timestamp_dst(self):
46484673 result = df .groupby ('time' ).apply (foo_udf ).sort ('time' )
46494674 self .assertPandasEqual (df .toPandas (), result .toPandas ())
46504675
4676+ def test_udf_with_key (self ):
4677+ from pyspark .sql .functions import pandas_udf , col , PandasUDFType
4678+ df = self .data
4679+ pdf = df .toPandas ()
4680+
4681+ def foo1 (key , pdf ):
4682+ import numpy as np
4683+ assert type (key ) == tuple
4684+ assert type (key [0 ]) == np .int64
4685+
4686+ return pdf .assign (v1 = key [0 ],
4687+ v2 = pdf .v * key [0 ],
4688+ v3 = pdf .v * pdf .id ,
4689+ v4 = pdf .v * pdf .id .mean ())
4690+
4691+ def foo2 (key , pdf ):
4692+ import numpy as np
4693+ assert type (key ) == tuple
4694+ assert type (key [0 ]) == np .int64
4695+ assert type (key [1 ]) == np .int32
4696+
4697+ return pdf .assign (v1 = key [0 ],
4698+ v2 = key [1 ],
4699+ v3 = pdf .v * key [0 ],
4700+ v4 = pdf .v + key [1 ])
4701+
4702+ def foo3 (key , pdf ):
4703+ assert type (key ) == tuple
4704+ assert len (key ) == 0
4705+ return pdf .assign (v1 = pdf .v * pdf .id )
4706+
4707+ # v2 is int because numpy.int64 * pd.Series<int32> results in pd.Series<int32>
4708+ # v3 is long because pd.Series<int64> * pd.Series<int32> results in pd.Series<int64>
4709+ udf1 = pandas_udf (
4710+ foo1 ,
4711+ 'id long, v int, v1 long, v2 int, v3 long, v4 double' ,
4712+ PandasUDFType .GROUPED_MAP )
4713+
4714+ udf2 = pandas_udf (
4715+ foo2 ,
4716+ 'id long, v int, v1 long, v2 int, v3 int, v4 int' ,
4717+ PandasUDFType .GROUPED_MAP )
4718+
4719+ udf3 = pandas_udf (
4720+ foo3 ,
4721+ 'id long, v int, v1 long' ,
4722+ PandasUDFType .GROUPED_MAP )
4723+
4724+ # Test groupby column
4725+ result1 = df .groupby ('id' ).apply (udf1 ).sort ('id' , 'v' ).toPandas ()
4726+ expected1 = pdf .groupby ('id' )\
4727+ .apply (lambda x : udf1 .func ((x .id .iloc [0 ],), x ))\
4728+ .sort_values (['id' , 'v' ]).reset_index (drop = True )
4729+ self .assertPandasEqual (expected1 , result1 )
4730+
4731+ # Test groupby expression
4732+ result2 = df .groupby (df .id % 2 ).apply (udf1 ).sort ('id' , 'v' ).toPandas ()
4733+ expected2 = pdf .groupby (pdf .id % 2 )\
4734+ .apply (lambda x : udf1 .func ((x .id .iloc [0 ] % 2 ,), x ))\
4735+ .sort_values (['id' , 'v' ]).reset_index (drop = True )
4736+ self .assertPandasEqual (expected2 , result2 )
4737+
4738+ # Test complex groupby
4739+ result3 = df .groupby (df .id , df .v % 2 ).apply (udf2 ).sort ('id' , 'v' ).toPandas ()
4740+ expected3 = pdf .groupby ([pdf .id , pdf .v % 2 ])\
4741+ .apply (lambda x : udf2 .func ((x .id .iloc [0 ], (x .v % 2 ).iloc [0 ],), x ))\
4742+ .sort_values (['id' , 'v' ]).reset_index (drop = True )
4743+ self .assertPandasEqual (expected3 , result3 )
4744+
4745+ # Test empty groupby
4746+ result4 = df .groupby ().apply (udf3 ).sort ('id' , 'v' ).toPandas ()
4747+ expected4 = udf3 .func ((), pdf )
4748+ self .assertPandasEqual (expected4 , result4 )
4749+
46514750
46524751@unittest .skipIf (
46534752 not _have_pandas or not _have_pyarrow ,
0 commit comments