Skip to content
Merged
Prev Previous commit
Next Next commit
Make sure spreads are evaluated only once
We need to access them twice because we first need to take their length, then
append them to the buffer. If a spread might have side effects, lift all side-effecting
arguments out in the order of occurrence.
  • Loading branch information
odersky committed Sep 26, 2025
commit 1f333bc4716b817d2bbedc7e0cf56785e89ef3c3
105 changes: 63 additions & 42 deletions compiler/src/dotty/tools/dotc/transform/PostTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@ import config.Printers.typr
import config.Feature
import util.{SrcPos, Stats}
import reporting.*
import NameKinds.WildcardParamName
import NameKinds.{WildcardParamName, TempResultName}
import typer.Applications.{spread, HasSpreads}
import typer.Implicits.SearchFailureType
import Constants.Constant
import cc.*
import dotty.tools.dotc.transform.MacroAnnotations.hasMacroAnnotation
import dotty.tools.dotc.core.NameKinds.DefaultGetterName
import ast.TreeInfo

object PostTyper {
val name: String = "posttyper"
Expand Down Expand Up @@ -379,6 +380,25 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
case _ =>
tpt

private def evalSpreadsOnce(trees: List[Tree])(within: List[Tree] => Tree)(using Context): Tree =
if trees.exists:
case spread(elem) => !(exprPurity(elem) >= TreeInfo.Idempotent)
case _ => false
then
val lifted = new mutable.ListBuffer[ValDef]
def liftIfImpure(tree: Tree): Tree = tree match
case tree @ Apply(fn, args) if fn.symbol == defn.spreadMethod =>
cpy.Apply(tree)(fn, args.mapConserve(liftIfImpure))
case _ if tpd.exprPurity(tree) >= TreeInfo.Idempotent =>
tree
case _ =>
val vdef = SyntheticValDef(TempResultName.fresh(), tree)
lifted += vdef
Ident(vdef.namedType)
val pureTrees = trees.mapConserve(liftIfImpure)
Block(lifted.toList, within(pureTrees))
else within(trees)

/** Translate sequence literal containing spread operators. Example:
*
* val xs, ys: List[Int]
Expand All @@ -400,50 +420,51 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
* at typer, we don't have all type variables instantiated yet.
*/
private def flattenSpreads[T](tree: SeqLiteral)(using Context): Tree =
val SeqLiteral(elems, elemtpt) = tree
val SeqLiteral(rawElems, elemtpt) = tree
val elemType = elemtpt.tpe
val elemCls = elemType.classSymbol

val lengthCalls = elems.collect:
case spread(elem) => elem.select(nme.length)
val singleElemCount: Tree = Literal(Constant(elems.length - lengthCalls.length))
val totalLength =
lengthCalls.foldLeft(singleElemCount): (acc, len) =>
acc.select(defn.Int_+).appliedTo(len)

def makeBuilder(name: String) =
ref(defn.ArraySeqBuilderModule).select(name.toTermName)
def genericBuilder = makeBuilder("generic")
.appliedToType(elemType)
.appliedTo(totalLength)

val builder =
if defn.ScalaValueClasses().contains(elemCls) then
makeBuilder(s"of${elemCls.name}").appliedTo(totalLength)
else if elemCls.derivesFrom(defn.ObjectClass) then
val classTagType = defn.ClassTagClass.typeRef.appliedTo(elemType)
val classTag = atPhase(Phases.typerPhase):
ctx.typer.inferImplicitArg(classTagType, tree.span.startPos)
classTag.tpe match
case _: SearchFailureType =>
genericBuilder
case _ =>
makeBuilder("ofRef")
.appliedToType(elemType)
.appliedTo(totalLength)
.appliedTo(classTag)
else
genericBuilder

elems.foldLeft(builder): (bldr, elem) =>
elem match
case spread(arg) =>
val selector =
if arg.tpe.derivesFrom(defn.SeqClass) then "addSeq"
else "addArray"
bldr.select(selector.toTermName).appliedTo(arg)
case _ => bldr.select("add".toTermName).appliedTo(elem)
.select("result".toTermName)
evalSpreadsOnce(rawElems): elems =>
val lengthCalls = elems.collect:
case spread(elem) => elem.select(nme.length)
val singleElemCount: Tree = Literal(Constant(elems.length - lengthCalls.length))
val totalLength =
lengthCalls.foldLeft(singleElemCount): (acc, len) =>
acc.select(defn.Int_+).appliedTo(len)

def makeBuilder(name: String) =
ref(defn.ArraySeqBuilderModule).select(name.toTermName)
def genericBuilder = makeBuilder("generic")
.appliedToType(elemType)
.appliedTo(totalLength)

val builder =
if defn.ScalaValueClasses().contains(elemCls) then
makeBuilder(s"of${elemCls.name}").appliedTo(totalLength)
else if elemCls.derivesFrom(defn.ObjectClass) then
val classTagType = defn.ClassTagClass.typeRef.appliedTo(elemType)
val classTag = atPhase(Phases.typerPhase):
ctx.typer.inferImplicitArg(classTagType, tree.span.startPos)
classTag.tpe match
case _: SearchFailureType =>
genericBuilder
case _ =>
makeBuilder("ofRef")
.appliedToType(elemType)
.appliedTo(totalLength)
.appliedTo(classTag)
else
genericBuilder

elems.foldLeft(builder): (bldr, elem) =>
elem match
case spread(arg) =>
val selector =
if arg.tpe.derivesFrom(defn.SeqClass) then "addSeq"
else "addArray"
bldr.select(selector.toTermName).appliedTo(arg)
case _ => bldr.select("add".toTermName).appliedTo(elem)
.select("result".toTermName)
end flattenSpreads

override def transform(tree: Tree)(using Context): Tree =
Expand Down
12 changes: 12 additions & 0 deletions tests/run/spreads.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
ArraySeq(1, 2, 3)
ArraySeq(1, 2, 3)
ArraySeq(1, 2, 1, 2, 3)
ArraySeq(1, 2, 1, 2, 3)
ArraySeq(1, 1, 2, 3, 2)
ArraySeq(1, 1, 2, 3, 2, 1, 2, 3, 3)
ArraySeq(1, 1, 2, 3, true, A, false)
ArraySeq(1, 1, 2, 3, 2)
one
one-two-three
two
ArraySeq(1, 1, 2, 3, 2)
11 changes: 11 additions & 0 deletions tests/run/spreads.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,20 @@ def useInt(xs: Int*) = ???

val xs = List(1, 2, 3)
val ys = List("A")
val ao = Option(1.0).toList

val x: Unit = use[Int](1, 2, xs*)
val y = use(1, 2, xs*)
use(1, xs*, 2)
use(1, xs*, 2, xs*, 3)
use(1, xs*, true, ys*, false)
use(1, identity(xs)*, 2)

def one = { println("one"); 1 }
def two = { println("two"); 2 }
def oneTwoThree = { println("one-two-three"); xs }
use(one, oneTwoThree*, two)
//use(1.0, ao*, 2.0)