Skip to content

Commit e29f833

Browse files
committed
Modify tests for unsupported types.
1 parent bcadac8 commit e29f833

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

python/pyspark/sql/tests.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3194,10 +3194,11 @@ def create_pandas_data_frame(self):
31943194
return pd.DataFrame(data=data_dict)
31953195

31963196
def test_unsupported_datatype(self):
3197-
schema = StructType([StructField("decimal", DecimalType(), True)])
3197+
schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
31983198
df = self.spark.createDataFrame([(None,)], schema=schema)
31993199
with QuietTest(self.sc):
3200-
self.assertRaises(Exception, lambda: df.toPandas())
3200+
with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
3201+
df.toPandas()
32013202

32023203
def test_null_conversion(self):
32033204
df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] +
@@ -3733,12 +3734,12 @@ def test_vectorized_udf_varargs(self):
37333734

37343735
def test_vectorized_udf_unsupported_types(self):
37353736
from pyspark.sql.functions import pandas_udf, col
3736-
schema = StructType([StructField("dt", DecimalType(), True)])
3737+
schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
37373738
df = self.spark.createDataFrame([(None,)], schema=schema)
3738-
f = pandas_udf(lambda x: x, DecimalType())
3739+
f = pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))
37393740
with QuietTest(self.sc):
37403741
with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
3741-
df.select(f(col('dt'))).collect()
3742+
df.select(f(col('map'))).collect()
37423743

37433744
def test_vectorized_udf_null_date(self):
37443745
from pyspark.sql.functions import pandas_udf, col
@@ -4032,7 +4033,8 @@ def test_wrong_args(self):
40324033
def test_unsupported_types(self):
40334034
from pyspark.sql.functions import pandas_udf, col, PandasUDFType
40344035
schema = StructType(
4035-
[StructField("id", LongType(), True), StructField("dt", DecimalType(), True)])
4036+
[StructField("id", LongType(), True),
4037+
StructField("map", MapType(StringType(), IntegerType()), True)])
40364038
df = self.spark.createDataFrame([(1, None,)], schema=schema)
40374039
f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUP_MAP)
40384040
with QuietTest(self.sc):

0 commit comments

Comments
 (0)