Skip to content

Commit 6eba655

Browse files
sarutakmarmbrus
authored andcommitted
[SPARK-12404][SQL] Ensure objects passed to StaticInvoke are Serializable
Currently, `StaticInvoke` receives `Any` as its object, and although `StaticInvoke` itself can be serialized, the object passed to it is sometimes not serializable. For example, the following code raises an exception because `RowEncoder#extractorsFor`, which is invoked indirectly, creates a `StaticInvoke`. ``` case class TimestampContainer(timestamp: java.sql.Timestamp) val rdd = sc.parallelize(1 to 2).map(_ => TimestampContainer(System.currentTimeMillis)) val df = rdd.toDF val ds = df.as[TimestampContainer] val rdd2 = ds.rdd <----------------- invokes extractorsFor indirectly ``` I'll add test cases. Author: Kousuke Saruta <sarutak@oss.nttdata.co.jp> Author: Michael Armbrust <michael@databricks.com> Closes #10357 from sarutak/SPARK-12404.
1 parent 41ee7c5 commit 6eba655

File tree

6 files changed

+88
-26
lines changed

6 files changed

+88
-26
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -194,15 +194,15 @@ object JavaTypeInference {
194194

195195
case c if c == classOf[java.sql.Date] =>
196196
StaticInvoke(
197-
DateTimeUtils,
197+
DateTimeUtils.getClass,
198198
ObjectType(c),
199199
"toJavaDate",
200200
getPath :: Nil,
201201
propagateNull = true)
202202

203203
case c if c == classOf[java.sql.Timestamp] =>
204204
StaticInvoke(
205-
DateTimeUtils,
205+
DateTimeUtils.getClass,
206206
ObjectType(c),
207207
"toJavaTimestamp",
208208
getPath :: Nil,
@@ -276,7 +276,7 @@ object JavaTypeInference {
276276
ObjectType(classOf[Array[Any]]))
277277

278278
StaticInvoke(
279-
ArrayBasedMapData,
279+
ArrayBasedMapData.getClass,
280280
ObjectType(classOf[JMap[_, _]]),
281281
"toJavaMap",
282282
keyData :: valueData :: Nil)
@@ -341,21 +341,21 @@ object JavaTypeInference {
341341

342342
case c if c == classOf[java.sql.Timestamp] =>
343343
StaticInvoke(
344-
DateTimeUtils,
344+
DateTimeUtils.getClass,
345345
TimestampType,
346346
"fromJavaTimestamp",
347347
inputObject :: Nil)
348348

349349
case c if c == classOf[java.sql.Date] =>
350350
StaticInvoke(
351-
DateTimeUtils,
351+
DateTimeUtils.getClass,
352352
DateType,
353353
"fromJavaDate",
354354
inputObject :: Nil)
355355

356356
case c if c == classOf[java.math.BigDecimal] =>
357357
StaticInvoke(
358-
Decimal,
358+
Decimal.getClass,
359359
DecimalType.SYSTEM_DEFAULT,
360360
"apply",
361361
inputObject :: Nil)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -223,15 +223,15 @@ object ScalaReflection extends ScalaReflection {
223223

224224
case t if t <:< localTypeOf[java.sql.Date] =>
225225
StaticInvoke(
226-
DateTimeUtils,
226+
DateTimeUtils.getClass,
227227
ObjectType(classOf[java.sql.Date]),
228228
"toJavaDate",
229229
getPath :: Nil,
230230
propagateNull = true)
231231

232232
case t if t <:< localTypeOf[java.sql.Timestamp] =>
233233
StaticInvoke(
234-
DateTimeUtils,
234+
DateTimeUtils.getClass,
235235
ObjectType(classOf[java.sql.Timestamp]),
236236
"toJavaTimestamp",
237237
getPath :: Nil,
@@ -287,7 +287,7 @@ object ScalaReflection extends ScalaReflection {
287287
ObjectType(classOf[Array[Any]]))
288288

289289
StaticInvoke(
290-
scala.collection.mutable.WrappedArray,
290+
scala.collection.mutable.WrappedArray.getClass,
291291
ObjectType(classOf[Seq[_]]),
292292
"make",
293293
arrayData :: Nil)
@@ -315,7 +315,7 @@ object ScalaReflection extends ScalaReflection {
315315
ObjectType(classOf[Array[Any]]))
316316

317317
StaticInvoke(
318-
ArrayBasedMapData,
318+
ArrayBasedMapData.getClass,
319319
ObjectType(classOf[Map[_, _]]),
320320
"toScalaMap",
321321
keyData :: valueData :: Nil)
@@ -548,28 +548,28 @@ object ScalaReflection extends ScalaReflection {
548548

549549
case t if t <:< localTypeOf[java.sql.Timestamp] =>
550550
StaticInvoke(
551-
DateTimeUtils,
551+
DateTimeUtils.getClass,
552552
TimestampType,
553553
"fromJavaTimestamp",
554554
inputObject :: Nil)
555555

556556
case t if t <:< localTypeOf[java.sql.Date] =>
557557
StaticInvoke(
558-
DateTimeUtils,
558+
DateTimeUtils.getClass,
559559
DateType,
560560
"fromJavaDate",
561561
inputObject :: Nil)
562562

563563
case t if t <:< localTypeOf[BigDecimal] =>
564564
StaticInvoke(
565-
Decimal,
565+
Decimal.getClass,
566566
DecimalType.SYSTEM_DEFAULT,
567567
"apply",
568568
inputObject :: Nil)
569569

570570
case t if t <:< localTypeOf[java.math.BigDecimal] =>
571571
StaticInvoke(
572-
Decimal,
572+
Decimal.getClass,
573573
DecimalType.SYSTEM_DEFAULT,
574574
"apply",
575575
inputObject :: Nil)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,21 +61,21 @@ object RowEncoder {
6161

6262
case TimestampType =>
6363
StaticInvoke(
64-
DateTimeUtils,
64+
DateTimeUtils.getClass,
6565
TimestampType,
6666
"fromJavaTimestamp",
6767
inputObject :: Nil)
6868

6969
case DateType =>
7070
StaticInvoke(
71-
DateTimeUtils,
71+
DateTimeUtils.getClass,
7272
DateType,
7373
"fromJavaDate",
7474
inputObject :: Nil)
7575

7676
case _: DecimalType =>
7777
StaticInvoke(
78-
Decimal,
78+
Decimal.getClass,
7979
DecimalType.SYSTEM_DEFAULT,
8080
"apply",
8181
inputObject :: Nil)
@@ -172,14 +172,14 @@ object RowEncoder {
172172

173173
case TimestampType =>
174174
StaticInvoke(
175-
DateTimeUtils,
175+
DateTimeUtils.getClass,
176176
ObjectType(classOf[java.sql.Timestamp]),
177177
"toJavaTimestamp",
178178
input :: Nil)
179179

180180
case DateType =>
181181
StaticInvoke(
182-
DateTimeUtils,
182+
DateTimeUtils.getClass,
183183
ObjectType(classOf[java.sql.Date]),
184184
"toJavaDate",
185185
input :: Nil)
@@ -197,7 +197,7 @@ object RowEncoder {
197197
"array",
198198
ObjectType(classOf[Array[_]]))
199199
StaticInvoke(
200-
scala.collection.mutable.WrappedArray,
200+
scala.collection.mutable.WrappedArray.getClass,
201201
ObjectType(classOf[Seq[_]]),
202202
"make",
203203
arrayData :: Nil)
@@ -210,7 +210,7 @@ object RowEncoder {
210210
val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType))
211211

212212
StaticInvoke(
213-
ArrayBasedMapData,
213+
ArrayBasedMapData.getClass,
214214
ObjectType(classOf[Map[_, _]]),
215215
"toScalaMap",
216216
keyData :: valueData :: Nil)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,14 @@ import org.apache.spark.sql.types._
4242
* of calling the function.
4343
*/
4444
case class StaticInvoke(
45-
staticObject: Any,
45+
staticObject: Class[_],
4646
dataType: DataType,
4747
functionName: String,
4848
arguments: Seq[Expression] = Nil,
4949
propagateNull: Boolean = true) extends Expression {
5050

51-
val objectName = staticObject match {
52-
case c: Class[_] => c.getName
53-
case other => other.getClass.getName.stripSuffix("$")
54-
}
51+
val objectName = staticObject.getName.stripSuffix("$")
52+
5553
override def nullable: Boolean = true
5654
override def children: Seq[Expression] = arguments
5755

sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.apache.spark.sql.test.TestSQLContext;
4040
import org.apache.spark.sql.catalyst.encoders.OuterScopes;
4141
import org.apache.spark.sql.catalyst.expressions.GenericRow;
42+
import org.apache.spark.sql.types.DecimalType;
4243
import org.apache.spark.sql.types.StructType;
4344

4445
import static org.apache.spark.sql.functions.*;
@@ -608,6 +609,44 @@ public int hashCode() {
608609
}
609610
}
610611

612+
public class SimpleJavaBean2 implements Serializable {
613+
private Timestamp a;
614+
private Date b;
615+
private java.math.BigDecimal c;
616+
617+
public Timestamp getA() { return a; }
618+
619+
public void setA(Timestamp a) { this.a = a; }
620+
621+
public Date getB() { return b; }
622+
623+
public void setB(Date b) { this.b = b; }
624+
625+
public java.math.BigDecimal getC() { return c; }
626+
627+
public void setC(java.math.BigDecimal c) { this.c = c; }
628+
629+
@Override
630+
public boolean equals(Object o) {
631+
if (this == o) return true;
632+
if (o == null || getClass() != o.getClass()) return false;
633+
634+
SimpleJavaBean that = (SimpleJavaBean) o;
635+
636+
if (!a.equals(that.a)) return false;
637+
if (!b.equals(that.b)) return false;
638+
return c.equals(that.c);
639+
}
640+
641+
@Override
642+
public int hashCode() {
643+
int result = a.hashCode();
644+
result = 31 * result + b.hashCode();
645+
result = 31 * result + c.hashCode();
646+
return result;
647+
}
648+
}
649+
611650
public class NestedJavaBean implements Serializable {
612651
private SimpleJavaBean a;
613652

@@ -689,4 +728,17 @@ public void testJavaBeanEncoder() {
689728
.as(Encoders.bean(SimpleJavaBean.class));
690729
Assert.assertEquals(data, ds3.collectAsList());
691730
}
731+
732+
@Test
733+
public void testJavaBeanEncoder2() {
734+
// This is a regression test of SPARK-12404
735+
OuterScopes.addOuterScope(this);
736+
SimpleJavaBean2 obj = new SimpleJavaBean2();
737+
obj.setA(new Timestamp(0));
738+
obj.setB(new Date(0));
739+
obj.setC(java.math.BigDecimal.valueOf(1));
740+
Dataset<SimpleJavaBean2> ds =
741+
context.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class));
742+
ds.collect();
743+
}
692744
}

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql
1919

2020
import java.io.{ObjectInput, ObjectOutput, Externalizable}
21+
import java.sql.{Date, Timestamp}
2122

2223
import scala.language.postfixOps
2324

@@ -42,6 +43,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
4243
1, 1, 1)
4344
}
4445

46+
47+
// Regression test for SPARK-12404: builds a one-row Dataset whose columns are a
// Timestamp, a Date, a java.math.BigDecimal and a scala.math.BigDecimal, then
// collects it. Only successful completion is checked.
test("SPARK-12404: Datatype Helper Serializablity") {
  val row = (
    new Timestamp(0),
    new Date(0),
    java.math.BigDecimal.valueOf(1),
    scala.math.BigDecimal(1))

  val ds = sparkContext.parallelize(row :: Nil).toDS()

  ds.collect()
}
56+
4557
test("collect, first, and take should use encoders for serialization") {
4658
val item = NonSerializableCaseClass("abcd")
4759
val ds = Seq(item).toDS()

0 commit comments

Comments
 (0)