Skip to content

Commit 5994711

Browse files
committed
Merge pull request scala#2717 from retronym/ticket/6574
SI-6574 Support @tailrec for extension methods.
2 parents f4ec281 + a90d1f0 commit 5994711

File tree

6 files changed

+79
-4
lines changed

6 files changed

+79
-4
lines changed

src/compiler/scala/tools/nsc/transform/ExtensionMethods.scala

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ abstract class ExtensionMethods extends Transform with TypingTransformers {
208208
companion.moduleClass.newMethod(extensionName, origMeth.pos, origMeth.flags & ~OVERRIDE & ~PROTECTED | FINAL)
209209
setAnnotations origMeth.annotations
210210
)
211+
origMeth.removeAnnotation(TailrecClass) // it's on the extension method, now.
211212
companion.info.decls.enter(extensionMeth)
212213
}
213214

@@ -221,15 +222,16 @@ abstract class ExtensionMethods extends Transform with TypingTransformers {
221222
val extensionParams = allParameters(extensionMono)
222223
val extensionThis = gen.mkAttributedStableRef(thiz setPos extensionMeth.pos)
223224

224-
val extensionBody = (
225-
rhs
225+
val extensionBody: Tree = {
226+
val tree = rhs
226227
.substituteSymbols(origTpeParams, extensionTpeParams)
227228
.substituteSymbols(origParams, extensionParams)
228229
.substituteThis(origThis, extensionThis)
229230
.changeOwner(origMeth -> extensionMeth)
230-
)
231+
new SubstututeRecursion(origMeth, extensionMeth, unit).transform(tree)
232+
}
231233

232-
// Record the extension method ( FIXME: because... ? )
234+
// Record the extension method. Later, in `Extender#transformStats`, these will be added to the companion object.
233235
extensionDefs(companion) += atPos(tree.pos)(DefDef(extensionMeth, extensionBody))
234236

235237
// These three lines are assembling Foo.bar$extension[T1, T2, ...]($this)
@@ -264,4 +266,33 @@ abstract class ExtensionMethods extends Transform with TypingTransformers {
264266
stat
265267
}
266268
}
269+
270+
final class SubstututeRecursion(origMeth: Symbol, extensionMeth: Symbol,
271+
unit: CompilationUnit) extends TypingTransformer(unit) {
272+
override def transform(tree: Tree): Tree = tree match {
273+
// SI-6574 Rewrite recursive calls against the extension method so they can
274+
// be tail call optimized later. The tailcalls phases comes before
275+
// erasure, which performs this translation more generally at all call
276+
// sites.
277+
//
278+
// // Source
279+
// class C[C] { def meth[M](a: A) = { { <expr>: C[C'] }.meth[M'] } }
280+
//
281+
// // Translation
282+
// class C[C] { def meth[M](a: A) = { { <expr>: C[C'] }.meth[M'](a1) } }
283+
// object C { def meth$extension[M, C](this$: C[C], a: A)
284+
// = { meth$extension[M', C']({ <expr>: C[C'] })(a1) } }
285+
case treeInfo.Applied(sel @ Select(qual, _), targs, argss) if sel.symbol == origMeth =>
286+
import gen.CODE._
287+
localTyper.typedPos(tree.pos) {
288+
val allArgss = List(qual) :: argss
289+
val origThis = extensionMeth.owner.companionClass
290+
val baseType = qual.tpe.baseType(origThis)
291+
val allTargs = targs.map(_.tpe) ::: baseType.typeArgs
292+
val fun = gen.mkAttributedTypeApply(THIS(extensionMeth.owner), extensionMeth, allTargs)
293+
allArgss.foldLeft(fun)(Apply(_, _))
294+
}
295+
case _ => super.transform(tree)
296+
}
297+
}
267298
}

test/files/neg/t6574.check

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
t6574.scala:4: error: could not optimize @tailrec annotated method notTailPos$extension: it contains a recursive call not in tail position
2+
println("tail")
3+
^
4+
t6574.scala:8: error: could not optimize @tailrec annotated method differentTypeArgs$extension: it is called recursively with different type arguments
5+
{(); new Bad[String, Unit](0)}.differentTypeArgs
6+
^
7+
two errors found

test/files/neg/t6574.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
class Bad[X, Y](val v: Int) extends AnyVal {
2+
@annotation.tailrec final def notTailPos[Z](a: Int)(b: String) {
3+
this.notTailPos[Z](a)(b)
4+
println("tail")
5+
}
6+
7+
@annotation.tailrec final def differentTypeArgs {
8+
{(); new Bad[String, Unit](0)}.differentTypeArgs
9+
}
10+
}

test/files/pos/t6574.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
class Bad[X, Y](val v: Int) extends AnyVal {
2+
def vv = v
3+
@annotation.tailrec final def foo[Z](a: Int)(b: String) {
4+
this.foo[Z](a)(b)
5+
}
6+
7+
@annotation.tailrec final def differentReceiver {
8+
{(); new Bad[X, Y](0)}.differentReceiver
9+
}
10+
11+
@annotation.tailrec final def dependent[Z](a: Int)(b: String): b.type = {
12+
this.dependent[Z](a)(b)
13+
}
14+
}
15+
16+
class HK[M[_]](val v: Int) extends AnyVal {
17+
def hk[N[_]]: Unit = if (false) hk[M] else ()
18+
}
19+

test/files/run/t6574b.check

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
List(5, 4, 3, 2, 1)

test/files/run/t6574b.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
object Test extends App {
2+
implicit class AnyOps(val i: Int) extends AnyVal {
3+
private def parentsOf(x: Int): List[Int] = if (x == 0) Nil else x :: parentsOf(x - 1)
4+
def parents: List[Int] = parentsOf(i)
5+
}
6+
println((5).parents)
7+
}

0 commit comments

Comments
 (0)