From ae0b55ab69c8b9cd18b709fd2f66b2e8071e7caa Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Sat, 23 Jan 2016 18:23:41 -0800
Subject: [PATCH 1/2] Checks row length when converting Java arrays to Python rows

---
 python/pyspark/sql/tests.py                           | 9 +++++++++
 .../scala/org/apache/spark/sql/execution/python.scala | 7 ++++++-
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index ae8620274dd2..8122ab83ed87 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -364,6 +364,15 @@ def test_infer_schema_to_local(self):
         df3 = self.sqlCtx.createDataFrame(rdd, df.schema)
         self.assertEqual(10, df3.count())
 
+    def test_create_dataframe_schema_mismatch(self):
+        input = [Row(a=1)]
+        rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))
+        schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())])
+        df = self.sqlCtx.createDataFrame(rdd, schema)
+        message = ".*assertion failed: Row length 1 and schema length 2 don't match.*"
+        with self.assertRaisesRegexp(Exception, message):
+            df.show()
+
     def test_serialize_nested_array_and_map(self):
         d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
         rdd = self.sc.parallelize(d)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala
index 41e35fd724cd..5fa4e37cbd69 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala
@@ -220,7 +220,12 @@ object EvaluatePython {
       ArrayBasedMapData(keys, values)
 
     case (c, StructType(fields)) if c.getClass.isArray =>
-      new GenericInternalRow(c.asInstanceOf[Array[_]].zip(fields).map {
+      val array = c.asInstanceOf[Array[_]]
+      assert(
+        array.length == fields.length,
+        s"Row length ${array.length} and schema length ${fields.length} don't match"
+      )
+      new GenericInternalRow(array.zip(fields).map {
         case (e, f) => fromJava(e, f.dataType)
       })
 

From ad8efa122c21be675111c1bbaeae607058e5c8fa Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Sun, 24 Jan 2016 16:43:13 -0800
Subject: [PATCH 2/2] Addresses PR comment

---
 python/pyspark/sql/tests.py                           |  2 +-
 .../scala/org/apache/spark/sql/execution/python.scala | 10 ++++++----
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 8122ab83ed87..7593b991a780 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -369,7 +369,7 @@ def test_create_dataframe_schema_mismatch(self):
         rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))
         schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())])
         df = self.sqlCtx.createDataFrame(rdd, schema)
-        message = ".*assertion failed: Row length 1 and schema length 2 don't match.*"
+        message = ".*Input row doesn't have expected number of values required by the schema.*"
         with self.assertRaisesRegexp(Exception, message):
             df.show()
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala
index 5fa4e37cbd69..e3a016e18db8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala
@@ -221,10 +221,12 @@ object EvaluatePython {
 
     case (c, StructType(fields)) if c.getClass.isArray =>
       val array = c.asInstanceOf[Array[_]]
-      assert(
-        array.length == fields.length,
-        s"Row length ${array.length} and schema length ${fields.length} don't match"
-      )
+      if (array.length != fields.length) {
+        throw new IllegalStateException(
+          s"Input row doesn't have expected number of values required by the schema. " +
+            s"${fields.length} fields are required while ${array.length} values are provided."
+        )
+      }
       new GenericInternalRow(array.zip(fields).map {
         case (e, f) => fromJava(e, f.dataType)
       })
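
Note: the behavior these patches change is easiest to see end to end from PySpark.
The sketch below mirrors the new test case, test_create_dataframe_schema_mismatch;
it assumes a Spark 1.6-era shell where sc (SparkContext) and sqlCtx (SQLContext)
already exist, as they do in the test harness:

    from pyspark.sql import Row
    from pyspark.sql.types import StructType, StructField, IntegerType, StringType

    # Each row carries a single value, but the schema declares two fields.
    rdd = sc.parallelize(range(3)).map(lambda i: Row(a=i))
    schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())])
    df = sqlCtx.createDataFrame(rdd, schema)

    # createDataFrame is lazy, so the mismatch only surfaces when rows are
    # actually converted on the JVM side, e.g. by an action. After PATCH 2/2
    # this raises an IllegalStateException: "Input row doesn't have expected
    # number of values required by the schema. 2 fields are required while
    # 1 values are provided."
    df.show()

Raising IllegalStateException instead of relying on assert also means the check
cannot be compiled away by scalac's -Xelide-below, so the user always gets the
descriptive message rather than a downstream ArrayIndexOutOfBoundsException.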