Skip to content
Merged
Next Next commit
Allow multiple spreads in function arguments
  • Loading branch information
odersky committed Sep 26, 2025
commit 0bc9f43019d0fc3ceb1a3b1c06a142dece1de856
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/config/Feature.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ object Feature:
val modularity = experimental("modularity")
val quotedPatternsWithPolymorphicFunctions = experimental("quotedPatternsWithPolymorphicFunctions")
val packageObjectValues = experimental("packageObjectValues")
val multiSpreads = experimental("multiSpreads")
val subCases = experimental("subCases")

def experimentalAutoEnableFeatures(using Context): List[TermName] =
Expand Down
9 changes: 8 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,11 @@ class Definitions {
@tu lazy val throwMethod: TermSymbol = enterMethod(OpsPackageClass, nme.THROWkw,
MethodType(List(ThrowableType), NothingType))

@tu lazy val spreadMethod = enterMethod(OpsPackageClass, nme.spread,
PolyType(TypeBounds.empty :: Nil)(
tl => MethodType(AnyType :: Nil, tl.paramRefs(0))
))

@tu lazy val NothingClass: ClassSymbol = enterCompleteClassSymbol(
ScalaPackageClass, tpnme.Nothing, AbstractFinal, List(AnyType))
def NothingType: TypeRef = NothingClass.typeRef
Expand Down Expand Up @@ -519,6 +524,8 @@ class Definitions {
@tu lazy val newGenericArrayMethod: TermSymbol = DottyArraysModule.requiredMethod("newGenericArray")
@tu lazy val newArrayMethod: TermSymbol = DottyArraysModule.requiredMethod("newArray")

@tu lazy val ArraySeqBuilderModule: Symbol = requiredModule("scala.runtime.ArraySeqBuilder")

def getWrapVarargsArrayModule: Symbol = ScalaRuntimeModule

// The set of all wrap{X, Ref}Array methods, where X is a value type
Expand Down Expand Up @@ -2234,7 +2241,7 @@ class Definitions {

/** Lists core methods that don't have underlying bytecode, but are synthesized on-the-fly in every reflection universe */
@tu lazy val syntheticCoreMethods: List[TermSymbol] =
AnyMethods ++ ObjectMethods ++ List(String_+, throwMethod)
AnyMethods ++ ObjectMethods ++ List(String_+, throwMethod, spreadMethod)

@tu lazy val reservedScalaClassNames: Set[Name] = syntheticScalaClasses.map(_.name).toSet

Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,7 @@ object StdNames {
val setSymbol: N = "setSymbol"
val setType: N = "setType"
val setTypeSignature: N = "setTypeSignature"
val spread: N = "spread"
val standardInterpolator: N = "standardInterpolator"
val staticClass : N = "staticClass"
val staticModule : N = "staticModule"
Expand Down
15 changes: 10 additions & 5 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1056,17 +1056,22 @@ object Parsers {
}

/** Is current ident a `*`, and is it followed by a `)`, `, )`, `,EOF`? The latter two are not
syntactically valid, but we need to include them here for error recovery. */
syntactically valid, but we need to include them here for error recovery.
Under experimental.multiSpreads we allow `*`` followed by `,` unconditionally.
*/
def followingIsVararg(): Boolean =
in.isIdent(nme.raw.STAR) && {
val lookahead = in.LookaheadScanner()
lookahead.nextToken()
lookahead.token == RPAREN
|| lookahead.token == COMMA
&& {
lookahead.nextToken()
lookahead.token == RPAREN || lookahead.token == EOF
}
&& (
in.featureEnabled(Feature.multiSpreads)
|| {
lookahead.nextToken()
lookahead.token == RPAREN || lookahead.token == EOF
}
)
}

/** When encountering a `:`, is that in the binding of a lambda?
Expand Down
72 changes: 72 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/PostTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ import config.Feature
import util.{SrcPos, Stats}
import reporting.*
import NameKinds.WildcardParamName
import typer.Applications.{spread, HasSpreads}
import typer.Implicits.SearchFailureType
import Constants.Constant
import cc.*
import dotty.tools.dotc.transform.MacroAnnotations.hasMacroAnnotation
import dotty.tools.dotc.core.NameKinds.DefaultGetterName
Expand Down Expand Up @@ -376,6 +379,73 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
case _ =>
tpt

/** Translate sequence literal containing spread operators. Example:
*
* val xs, ys: List[Int]
* [1, xs*, 2, ys*]
*
* Here the sequence literal is translated at typer to
*
* [1, spread(xs), 2, spread(ys)]
*
* This then translates to
*
* scala.runtime.ArraySeqBuilcder.ofInt(2 + xs.length + ys.length)
* .add(1)
* .addSeq(xs)
* .add(2)
* .addSeq(ys)
*
* The reason for doing a two-step typer/postTyper translation is that
* at typer, we don't have all type variables instantiated yet.
*/
private def flattenSpreads[T](tree: SeqLiteral)(using Context): Tree =
val SeqLiteral(elems, elemtpt) = tree
val elemType = elemtpt.tpe
val elemCls = elemType.classSymbol

val lengthCalls = elems.collect:
case spread(elem) => elem.select(nme.length)
val singleElemCount: Tree = Literal(Constant(elems.length - lengthCalls.length))
val totalLength =
lengthCalls.foldLeft(singleElemCount): (acc, len) =>
acc.select(defn.Int_+).appliedTo(len)

def makeBuilder(name: String) =
ref(defn.ArraySeqBuilderModule).select(name.toTermName)
def genericBuilder = makeBuilder("generic")
.appliedToType(elemType)
.appliedTo(totalLength)

val builder =
if defn.ScalaValueClasses().contains(elemCls) then
makeBuilder(s"of${elemCls.name}").appliedTo(totalLength)
else if elemCls.derivesFrom(defn.ObjectClass) then
val classTagType = defn.ClassTagClass.typeRef.appliedTo(elemType)
val classTag = atPhase(Phases.typerPhase):
ctx.typer.inferImplicitArg(classTagType, tree.span.startPos)
classTag.tpe match
case _: SearchFailureType =>
genericBuilder
case _ =>
makeBuilder("ofRef")
.appliedToType(elemType)
.appliedTo(totalLength)
.appliedTo(classTag)
else
genericBuilder

elems.foldLeft(builder): (bldr, elem) =>
elem match
case spread(arg) =>
val selector =
if arg.tpe.derivesFrom(defn.SeqClass) then "addSeq"
else "addArray"
bldr.select(selector.toTermName).appliedTo(arg)
case _ => bldr.select("add".toTermName).appliedTo(elem)
.select("result".toTermName)
end flattenSpreads

override def transform(tree: Tree)(using Context): Tree =
try tree match {
// TODO move CaseDef case lower: keep most probable trees first for performance
Expand Down Expand Up @@ -592,6 +662,8 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
case tree: RefinedTypeTree =>
Checking.checkPolyFunctionType(tree)
super.transform(tree)
case tree: SeqLiteral if tree.hasAttachment(HasSpreads) =>
flattenSpreads(tree)
case _: Quote | _: QuotePattern =>
ctx.compilationUnit.needsStaging = true
super.transform(tree)
Expand Down
43 changes: 33 additions & 10 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import Inferencing.*
import reporting.*
import Nullables.*, NullOpsDecorator.*
import config.{Feature, MigrationVersion, SourceVersion}
import util.Property

import collection.mutable
import config.Printers.{overload, typr, unapp}
Expand All @@ -42,6 +43,17 @@ import dotty.tools.dotc.inlines.Inlines
object Applications {
import tpd.*

/** Attachment key for SeqLiterals containing spreads. Eliminated at PostTyper */
val HasSpreads = new Property.StickyKey[Unit]

/** An extractor for spreads in sequence literals */
object spread:
def apply(arg: Tree, elemtpt: Tree)(using Context) =
ref(defn.spreadMethod).appliedToTypeTree(elemtpt).appliedTo(arg)
def unapply(arg: Apply)(using Context): Option[Tree] = arg match
case Apply(fn, x :: Nil) if fn.symbol == defn.spreadMethod => Some(x)
case _ => None

def extractorMember(tp: Type, name: Name)(using Context): SingleDenotation =
tp.member(name).suchThat(sym => sym.info.isParameterless && sym.info.widenExpr.isValueType)

Expand Down Expand Up @@ -797,14 +809,19 @@ trait Applications extends Compatibility {
addTyped(arg)
case _ =>
val elemFormal = formal.widenExpr.argTypesLo.head
val typedArgs =
harmonic(harmonizeArgs, elemFormal) {
args.map { arg =>
if Feature.enabled(Feature.multiSpreads)
&& !ctx.isAfterTyper && args.exists(isVarArg)
then
args.foreach: arg =>
if isVarArg(arg)
then addArg(typedArg(arg, formal), formal)
else addArg(typedArg(arg, elemFormal), elemFormal)
else
val typedArgs = harmonic(harmonizeArgs, elemFormal):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAICT, there is no equivalent code in the code path with multiSpreads. I understand it may be tricky, but it's at least worth calling out at the spec/SIP level. IIUC It means we get:

def foo[A](xs: A*): List[A] = xs.toList

foo(5, 5.6) // List[Double]
val doubles = List(5.6)
foo(5, doubles*) // List[AnyVal]

Basically having any spread in a vararg argument disables harmonization entirely.

args.map: arg =>
checkNoVarArg(arg)
typedArg(arg, elemFormal)
}
}
typedArgs.foreach(addArg(_, elemFormal))
typedArgs.foreach(addArg(_, elemFormal))
makeVarArg(args.length, elemFormal)
}
else args match {
Expand Down Expand Up @@ -944,12 +961,18 @@ trait Applications extends Compatibility {
typedArgBuf += typedArg
ok = ok & !typedArg.tpe.isError

def makeVarArg(n: Int, elemFormal: Type): Unit = {
def makeVarArg(n: Int, elemFormal: Type): Unit =
val args = typedArgBuf.takeRight(n).toList
typedArgBuf.dropRightInPlace(n)
val elemtpt = TypeTree(elemFormal.normalizedTupleType, inferred = true)
typedArgBuf += seqToRepeated(SeqLiteral(args, elemtpt))
}
val elemTpe = elemFormal.normalizedTupleType
val elemtpt = TypeTree(elemTpe, inferred = true)
def wrapSpread(arg: Tree): Tree = arg match
case Typed(argExpr, tpt) if tpt.tpe.isRepeatedParam => spread(argExpr, elemtpt)
case _ => arg
val args1 = args.mapConserve(wrapSpread)
val seqLit = SeqLiteral(args1, elemtpt)
if args1 ne args then seqLit.putAttachment(HasSpreads, ())
typedArgBuf += seqToRepeated(seqLit)

def harmonizeArgs(args: List[TypedArg]): List[Tree] =
// harmonize args only if resType depends on parameter types
Expand Down
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Empty file added by accident?

Empty file.
6 changes: 5 additions & 1 deletion library/src/scala/language.scala
Original file line number Diff line number Diff line change
Expand Up @@ -350,11 +350,15 @@ object language {
@compileTimeOnly("`packageObjectValues` can only be used at compile time in import statements")
object packageObjectValues

/** Experimental support for multiple spread arguments.
*/
@compileTimeOnly("`multiSpreads` can only be used at compile time in import statements")
object multiSpreads

/** Experimental support for match expressions with sub cases.
*/
@compileTimeOnly("`subCases` can only be used at compile time in import statements")
object subCases

}

/** The deprecated object contains features that are no longer officially suypported in Scala.
Expand Down
Loading