Skip to content

Commit 3a12413

Browse files
authored
Merge pull request scala#10118 from lrytz/deser
Prevent Function0 execution during LazyList deserialization
2 parents 88d56ab + f24c226 commit 3a12413

File tree

2 files changed

+79
-5
lines changed

2 files changed

+79
-5
lines changed

src/library/scala/collection/immutable/LazyList.scala

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import java.lang.{StringBuilder => JStringBuilder}
1919

2020
import scala.annotation.tailrec
2121
import scala.collection.generic.SerializeEnd
22-
import scala.collection.mutable.{ArrayBuffer, Builder, ReusableBuilder, StringBuilder}
22+
import scala.collection.mutable.{Builder, ReusableBuilder, StringBuilder}
2323
import scala.language.implicitConversions
2424
import scala.runtime.Statics
2525

@@ -1353,7 +1353,7 @@ object LazyList extends SeqFactory[LazyList] {
13531353
private[this] def writeObject(out: ObjectOutputStream): Unit = {
13541354
out.defaultWriteObject()
13551355
var these = coll
1356-
while(these.knownNonEmpty) {
1356+
while (these.knownNonEmpty) {
13571357
out.writeObject(these.head)
13581358
these = these.tail
13591359
}
@@ -1363,14 +1363,17 @@ object LazyList extends SeqFactory[LazyList] {
13631363

13641364
private[this] def readObject(in: ObjectInputStream): Unit = {
13651365
in.defaultReadObject()
1366-
val init = new ArrayBuffer[A]
1366+
val init = new mutable.ListBuffer[A]
13671367
var initRead = false
13681368
while (!initRead) in.readObject match {
13691369
case SerializeEnd => initRead = true
13701370
case a => init += a.asInstanceOf[A]
13711371
}
13721372
val tail = in.readObject().asInstanceOf[LazyList[A]]
1373-
coll = init ++: tail
1373+
// scala/scala#10118: caution that no code path can evaluate `tail.state`
1374+
// before the resulting LazyList is returned
1375+
val it = init.toList.iterator
1376+
coll = newLL(stateFromIteratorConcatSuffix(it)(tail.state))
13741377
}
13751378

13761379
private[this] def readResolve(): Any = coll

test/junit/scala/collection/immutable/LazyListTest.scala

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,79 @@ import org.junit.Assert._
88

99
import scala.annotation.unused
1010
import scala.collection.mutable.{Builder, ListBuffer}
11-
import scala.tools.testkit.AssertUtil
11+
import scala.tools.testkit.{AssertUtil, ReflectUtil}
1212
import scala.util.Try
1313

1414
@RunWith(classOf[JUnit4])
1515
class LazyListTest {
1616

17+
@Test
18+
def serialization(): Unit = {
19+
import java.io._
20+
21+
def serialize(obj: AnyRef): Array[Byte] = {
22+
val buffer = new ByteArrayOutputStream
23+
val out = new ObjectOutputStream(buffer)
24+
out.writeObject(obj)
25+
buffer.toByteArray
26+
}
27+
28+
def deserialize(a: Array[Byte]): AnyRef = {
29+
val in = new ObjectInputStream(new ByteArrayInputStream(a))
30+
in.readObject
31+
}
32+
33+
def serializeDeserialize[T <: AnyRef](obj: T) = deserialize(serialize(obj)).asInstanceOf[T]
34+
35+
val l = LazyList.from(10)
36+
37+
val ld1 = serializeDeserialize(l)
38+
assertEquals(l.take(10).toList, ld1.take(10).toList)
39+
40+
l.tail.head
41+
val ld2 = serializeDeserialize(l)
42+
assertEquals(l.take(10).toList, ld2.take(10).toList)
43+
44+
LazyListTest.serializationForceCount = 0
45+
val u = LazyList.from(10).map(x => { LazyListTest.serializationForceCount += 1; x })
46+
47+
@unused def printDiff(): Unit = {
48+
val a = serialize(u)
49+
ReflectUtil.getFieldAccessible[LazyList[_]]("scala$collection$immutable$LazyList$$stateEvaluated").setBoolean(u, true)
50+
val b = serialize(u)
51+
val i = a.zip(b).indexWhere(p => p._1 != p._2)
52+
println("difference: ")
53+
println(s"val from = ${a.slice(i - 10, i + 10).mkString("List[Byte](", ", ", ")")}")
54+
println(s"val to = ${b.slice(i - 10, i + 10).mkString("List[Byte](", ", ", ")")}")
55+
}
56+
57+
// to update this test, comment-out `LazyList.writeReplace` and run `printDiff`
58+
// printDiff()
59+
60+
val from = List[Byte](83, 116, 97, 116, 101, 59, 120, 112, 0, 0, 0, 115, 114, 0, 33, 106, 97, 118, 97, 46)
61+
val to = List[Byte](83, 116, 97, 116, 101, 59, 120, 112, 0, 0, 1, 115, 114, 0, 33, 106, 97, 118, 97, 46)
62+
63+
assertEquals(LazyListTest.serializationForceCount, 0)
64+
65+
u.head
66+
assertEquals(LazyListTest.serializationForceCount, 1)
67+
68+
val data = serialize(u)
69+
var i = data.indexOfSlice(from)
70+
to.foreach(x => {data(i) = x; i += 1})
71+
72+
val ud1 = deserialize(data).asInstanceOf[LazyList[Int]]
73+
74+
// this check failed before scala/scala#10118, deserialization triggered evaluation
75+
assertEquals(LazyListTest.serializationForceCount, 1)
76+
77+
ud1.tail.head
78+
assertEquals(LazyListTest.serializationForceCount, 2)
79+
80+
u.tail.head
81+
assertEquals(LazyListTest.serializationForceCount, 3)
82+
}
83+
1784
@Test
1885
def t6727_and_t6440_and_8627(): Unit = {
1986
assertTrue(LazyList.continually(()).filter(_ => true).take(2) == Seq((), ()))
@@ -378,3 +445,7 @@ class LazyListTest {
378445
assertEquals(1, count)
379446
}
380447
}
448+
449+
object LazyListTest {
450+
var serializationForceCount = 0
451+
}

0 commit comments

Comments
 (0)