Skip to content

Commit 757a3a7

Browse files
committed
Merge pull request scala#1568 from retronym/ticket/6611
SI-6611 Tighten up an unsafe array optimization
2 parents 3d248ef + 092345a commit 757a3a7

File tree

7 files changed

+141
-14
lines changed

7 files changed

+141
-14
lines changed

src/compiler/scala/tools/nsc/transform/CleanUp.scala

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ abstract class CleanUp extends Transform with ast.TreeDSL {
1515
import global._
1616
import definitions._
1717
import CODE._
18+
import treeInfo.StripCast
1819

1920
/** the following two members override abstract members in Transform */
2021
val phaseName: String = "cleanup"
@@ -618,14 +619,16 @@ abstract class CleanUp extends Transform with ast.TreeDSL {
618619
}
619620
transformApply
620621

621-
// This transform replaces Array(Predef.wrapArray(Array(...)), <tag>)
622-
// with just Array(...)
623-
case Apply(appMeth, List(Apply(wrapRefArrayMeth, List(array)), _))
624-
if (wrapRefArrayMeth.symbol == Predef_wrapRefArray &&
625-
appMeth.symbol == ArrayModule_overloadedApply.suchThat {
626-
_.tpe.resultType.dealias.typeSymbol == ObjectClass
627-
}) =>
628-
super.transform(array)
622+
// Replaces `Array(Predef.wrapArray(ArrayValue(...).$asInstanceOf[...]), <tag>)`
623+
// with just `ArrayValue(...).$asInstanceOf[...]`
624+
//
625+
// See SI-6611; we must *only* do this for literal vararg arrays.
626+
case Apply(appMeth, List(Apply(wrapRefArrayMeth, List(arg @ StripCast(ArrayValue(_, _)))), _))
627+
if wrapRefArrayMeth.symbol == Predef_wrapRefArray && appMeth.symbol == ArrayModule_genericApply =>
628+
super.transform(arg)
629+
case Apply(appMeth, List(elem0, Apply(wrapArrayMeth, List(rest @ ArrayValue(elemtpt, _)))))
630+
if wrapArrayMeth.symbol == Predef_wrapArray(elemtpt.tpe) && appMeth.symbol == ArrayModule_apply(elemtpt.tpe) =>
631+
super.transform(rest.copy(elems = elem0 :: rest.elems))
629632

630633
case _ =>
631634
super.transform(tree)

src/library/scala/Array.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ object Array extends FallbackArrayBuilding {
115115
* @param xs the elements to put in the array
116116
* @return an array containing all elements from xs.
117117
*/
118+
// Subject to a compiler optimization in Cleanup.
119+
// Array(e0, ..., en) is translated to { val a = new Array(3); a(i) = ei; a }
118120
def apply[T: ClassTag](xs: T*): Array[T] = {
119121
val array = new Array[T](xs.length)
120122
var i = 0
@@ -123,6 +125,7 @@ object Array extends FallbackArrayBuilding {
123125
}
124126

125127
/** Creates an array of `Boolean` objects */
128+
// Subject to a compiler optimization in Cleanup, see above.
126129
def apply(x: Boolean, xs: Boolean*): Array[Boolean] = {
127130
val array = new Array[Boolean](xs.length + 1)
128131
array(0) = x
@@ -132,6 +135,7 @@ object Array extends FallbackArrayBuilding {
132135
}
133136

134137
/** Creates an array of `Byte` objects */
138+
// Subject to a compiler optimization in Cleanup, see above.
135139
def apply(x: Byte, xs: Byte*): Array[Byte] = {
136140
val array = new Array[Byte](xs.length + 1)
137141
array(0) = x
@@ -141,6 +145,7 @@ object Array extends FallbackArrayBuilding {
141145
}
142146

143147
/** Creates an array of `Short` objects */
148+
// Subject to a compiler optimization in Cleanup, see above.
144149
def apply(x: Short, xs: Short*): Array[Short] = {
145150
val array = new Array[Short](xs.length + 1)
146151
array(0) = x
@@ -150,6 +155,7 @@ object Array extends FallbackArrayBuilding {
150155
}
151156

152157
/** Creates an array of `Char` objects */
158+
// Subject to a compiler optimization in Cleanup, see above.
153159
def apply(x: Char, xs: Char*): Array[Char] = {
154160
val array = new Array[Char](xs.length + 1)
155161
array(0) = x
@@ -159,6 +165,7 @@ object Array extends FallbackArrayBuilding {
159165
}
160166

161167
/** Creates an array of `Int` objects */
168+
// Subject to a compiler optimization in Cleanup, see above.
162169
def apply(x: Int, xs: Int*): Array[Int] = {
163170
val array = new Array[Int](xs.length + 1)
164171
array(0) = x
@@ -168,6 +175,7 @@ object Array extends FallbackArrayBuilding {
168175
}
169176

170177
/** Creates an array of `Long` objects */
178+
// Subject to a compiler optimization in Cleanup, see above.
171179
def apply(x: Long, xs: Long*): Array[Long] = {
172180
val array = new Array[Long](xs.length + 1)
173181
array(0) = x
@@ -177,6 +185,7 @@ object Array extends FallbackArrayBuilding {
177185
}
178186

179187
/** Creates an array of `Float` objects */
188+
// Subject to a compiler optimization in Cleanup, see above.
180189
def apply(x: Float, xs: Float*): Array[Float] = {
181190
val array = new Array[Float](xs.length + 1)
182191
array(0) = x
@@ -186,6 +195,7 @@ object Array extends FallbackArrayBuilding {
186195
}
187196

188197
/** Creates an array of `Double` objects */
198+
// Subject to a compiler optimization in Cleanup, see above.
189199
def apply(x: Double, xs: Double*): Array[Double] = {
190200
val array = new Array[Double](xs.length + 1)
191201
array(0) = x

src/reflect/scala/reflect/internal/Definitions.scala

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -340,12 +340,13 @@ trait Definitions extends api.StandardDefinitions {
340340
lazy val PredefModule = requiredModule[scala.Predef.type]
341341
lazy val PredefModuleClass = PredefModule.moduleClass
342342

343-
def Predef_classOf = getMemberMethod(PredefModule, nme.classOf)
344-
def Predef_identity = getMemberMethod(PredefModule, nme.identity)
345-
def Predef_conforms = getMemberMethod(PredefModule, nme.conforms)
346-
def Predef_wrapRefArray = getMemberMethod(PredefModule, nme.wrapRefArray)
347-
def Predef_??? = getMemberMethod(PredefModule, nme.???)
348-
def Predef_implicitly = getMemberMethod(PredefModule, nme.implicitly)
343+
def Predef_classOf = getMemberMethod(PredefModule, nme.classOf)
344+
def Predef_identity = getMemberMethod(PredefModule, nme.identity)
345+
def Predef_conforms = getMemberMethod(PredefModule, nme.conforms)
346+
def Predef_wrapRefArray = getMemberMethod(PredefModule, nme.wrapRefArray)
347+
def Predef_wrapArray(tp: Type) = getMemberMethod(PredefModule, wrapArrayMethodName(tp))
348+
def Predef_??? = getMemberMethod(PredefModule, nme.???)
349+
def Predef_implicitly = getMemberMethod(PredefModule, nme.implicitly)
349350

350351
/** Is `sym` a member of Predef with the given name?
351352
* Note: DON't replace this by sym == Predef_conforms/etc, as Predef_conforms is a `def`
@@ -470,6 +471,8 @@ trait Definitions extends api.StandardDefinitions {
470471
// arrays and their members
471472
lazy val ArrayModule = requiredModule[scala.Array.type]
472473
lazy val ArrayModule_overloadedApply = getMemberMethod(ArrayModule, nme.apply)
474+
def ArrayModule_genericApply = ArrayModule_overloadedApply.suchThat(_.paramss.flatten.last.tpe.typeSymbol == ClassTagClass) // [T: ClassTag](xs: T*): Array[T]
475+
def ArrayModule_apply(tp: Type) = ArrayModule_overloadedApply.suchThat(_.tpe.resultType =:= arrayType(tp)) // (p1: AnyVal1, ps: AnyVal1*): Array[AnyVal1]
473476
lazy val ArrayClass = getRequiredClass("scala.Array") // requiredClass[scala.Array[_]]
474477
lazy val Array_apply = getMemberMethod(ArrayClass, nme.apply)
475478
lazy val Array_update = getMemberMethod(ArrayClass, nme.update)

src/reflect/scala/reflect/internal/TreeInfo.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,20 @@ abstract class TreeInfo {
265265
tree
266266
}
267267

268+
/** Strips layers of `.asInstanceOf[T]` / `_.$asInstanceOf[T]()` from an expression */
269+
def stripCast(tree: Tree): Tree = tree match {
270+
case TypeApply(sel @ Select(inner, _), _) if isCastSymbol(sel.symbol) =>
271+
stripCast(inner)
272+
case Apply(TypeApply(sel @ Select(inner, _), _), Nil) if isCastSymbol(sel.symbol) =>
273+
stripCast(inner)
274+
case t =>
275+
t
276+
}
277+
278+
object StripCast {
279+
def unapply(tree: Tree): Some[Tree] = Some(stripCast(tree))
280+
}
281+
268282
/** Is tree a self or super constructor call? */
269283
def isSelfOrSuperConstrCall(tree: Tree) = {
270284
// stripNamedApply for SI-3584: adaptToImplicitMethod in Typers creates a special context
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Method call statistics:
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import scala.tools.partest.instrumented.Instrumentation._
2+
3+
object Test {
4+
def main(args: Array[String]) {
5+
startProfiling()
6+
7+
// tests optimization in Cleanup for varargs reference arrays
8+
Array("")
9+
10+
11+
Array(true)
12+
Array(true, false)
13+
Array(1: Byte)
14+
Array(1: Byte, 2: Byte)
15+
Array(1: Short)
16+
Array(1: Short, 2: Short)
17+
Array(1)
18+
Array(1, 2)
19+
Array(1L)
20+
Array(1L, 2L)
21+
Array(1d)
22+
Array(1d, 2d)
23+
Array(1f)
24+
Array(1f, 2f)
25+
26+
/* Not currently optimized:
27+
Array[Int](1, 2) etc
28+
Array(())
29+
Array((), ())
30+
*/
31+
32+
stopProfiling()
33+
printStatistics()
34+
}
35+
}

test/files/run/t6611.scala

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
object Test extends App {
2+
locally {
3+
val a = Array("1")
4+
val a2 = Array(a: _*)
5+
assert(a ne a2)
6+
}
7+
8+
locally {
9+
val a = Array("1": Object)
10+
val a2 = Array(a: _*)
11+
assert(a ne a2)
12+
}
13+
14+
locally {
15+
val a = Array(true)
16+
val a2 = Array(a: _*)
17+
assert(a ne a2)
18+
}
19+
20+
locally {
21+
val a = Array(1: Short)
22+
val a2 = Array(a: _*)
23+
assert(a ne a2)
24+
}
25+
26+
locally {
27+
val a = Array(1: Byte)
28+
val a2 = Array(a: _*)
29+
assert(a ne a2)
30+
}
31+
32+
locally {
33+
val a = Array(1)
34+
val a2 = Array(a: _*)
35+
assert(a ne a2)
36+
}
37+
38+
locally {
39+
val a = Array(1L)
40+
val a2 = Array(a: _*)
41+
assert(a ne a2)
42+
}
43+
44+
locally {
45+
val a = Array(1f)
46+
val a2 = Array(a: _*)
47+
assert(a ne a2)
48+
}
49+
50+
locally {
51+
val a = Array(1d)
52+
val a2 = Array(a: _*)
53+
assert(a ne a2)
54+
}
55+
56+
locally {
57+
val a = Array(())
58+
val a2 = Array(a: _*)
59+
assert(a ne a2)
60+
}
61+
}

0 commit comments

Comments
 (0)