Skip to content

Commit 5d7d644

Browse files
committed
typedFunction undoes eta-expansion regardless of expected type
When recovering missing argument types for an eta-expanded method value, rework the expected type to a method type.
1 parent 8e32d00 commit 5d7d644

File tree

3 files changed

+41
-45
lines changed

3 files changed

+41
-45
lines changed

src/compiler/scala/tools/nsc/typechecker/EtaExpansion.scala

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,8 @@ import symtab.Flags._
1515
* @version 1.0
1616
*/
1717
trait EtaExpansion { self: Analyzer =>
18-
1918
import global._
2019

21-
object etaExpansion {
22-
private def isMatch(vparam: ValDef, arg: Tree) = arg match {
23-
case Ident(name) => vparam.name == name
24-
case _ => false
25-
}
26-
27-
def unapply(tree: Tree): Option[(List[ValDef], Tree, List[Tree])] = tree match {
28-
case Function(vparams, Apply(fn, args)) if (vparams corresponds args)(isMatch) =>
29-
Some((vparams, fn, args))
30-
case _ =>
31-
None
32-
}
33-
}
34-
3520
/** <p>
3621
* Expand partial function applications of type `type`.
3722
* </p><pre>

src/compiler/scala/tools/nsc/typechecker/Typers.scala

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2841,7 +2841,8 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
28412841
* - a type with a Single Abstract Method (under -Xexperimental for now).
28422842
*/
28432843
private def typedFunction(fun: Function, mode: Mode, pt: Type): Tree = {
2844-
val numVparams = fun.vparams.length
2844+
val vparams = fun.vparams
2845+
val numVparams = vparams.length
28452846
val FunctionSymbol =
28462847
if (numVparams > definitions.MaxFunctionArity) NoSymbol
28472848
else FunctionClass(numVparams)
@@ -2863,37 +2864,20 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
28632864
* TODO: handle vararg sams?
28642865
*/
28652866
val ptNorm =
2866-
if (samMatchesFunctionBasedOnArity(sam, fun.vparams)) samToFunctionType(pt, sam)
2867+
if (samMatchesFunctionBasedOnArity(sam, vparams)) samToFunctionType(pt, sam)
28672868
else pt
28682869
val (argpts, respt) =
28692870
ptNorm baseType FunctionSymbol match {
28702871
case TypeRef(_, FunctionSymbol, args :+ res) => (args, res)
2871-
case _ => (fun.vparams map (if (pt == ErrorType) (_ => ErrorType) else (_ => NoType)), WildcardType)
2872+
case _ => (vparams map (if (pt == ErrorType) (_ => ErrorType) else (_ => NoType)), WildcardType)
28722873
}
28732874

2874-
2875-
// if the function is `(a1: T1, ..., aN: TN) => fun(a1,..., aN)`, where Ti are not all fully defined,
2876-
// type `fun` directly
2877-
def typeUnEtaExpanded: Type = fun match {
2878-
case etaExpansion(_, fn, _) =>
2879-
silent(_.typed(fn, mode.forFunMode, pt)) filter (_ => context.undetparams.isEmpty) map { fn1 =>
2880-
// if context.undetparams is not empty, the function was polymorphic,
2881-
// so we need the missing arguments to infer its type. See #871
2882-
val ftpe = normalize(fn1.tpe) baseType FunctionClass(numVparams)
2883-
// println(s"typeUnEtaExpanded $fn : ${fn1.tpe} (unwrapped $fun) --> normalized: $ftpe")
2884-
2885-
if (isFunctionType(ftpe) && isFullyDefined(ftpe)) ftpe
2886-
else NoType
2887-
} orElse { _ => NoType }
2888-
case _ => NoType
2889-
}
2890-
28912875
if (!FunctionSymbol.exists) MaxFunctionArityError(fun)
28922876
else if (argpts.lengthCompare(numVparams) != 0) WrongNumberOfParametersError(fun, argpts)
28932877
else {
28942878
val paramsMissingType = mutable.ArrayBuffer.empty[ValDef] //.sizeHint(numVparams) probably useless, since initial size is 16 and max fun arity is 22
28952879
// first, try to define param types from expected function's arg types if needed
2896-
foreach2(fun.vparams, argpts) { (vparam, argpt) =>
2880+
foreach2(vparams, argpts) { (vparam, argpt) =>
28972881
if (vparam.tpt isEmpty) {
28982882
if (isFullyDefined(argpt)) vparam.tpt setType argpt
28992883
else paramsMissingType += vparam
@@ -2902,12 +2886,29 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
29022886
}
29032887
}
29042888

2905-
// if we had missing param types, see if we can undo eta-expansion and recover type info
2906-
val expectedFunTypeBeforeEtaExpansion =
2907-
if (paramsMissingType.isEmpty) NoType
2908-
else typeUnEtaExpanded
2889+
// If we're typing `(a1: T1, ..., aN: TN) => m(a1,..., aN)`, where some Ti are not fully defined,
2890+
// type `m` directly (undoing eta-expansion of method m) to determine the argument types.
2891+
val ptUnrollingEtaExpansion =
2892+
if (paramsMissingType.nonEmpty && pt != ErrorType) fun.body match {
2893+
case Apply(meth, args) if (vparams corresponds args) { case (p, Ident(name)) => p.name == name case _ => false } =>
2894+
val methArgs = NoSymbol.newSyntheticValueParams(argpts map { case NoType => WildcardType case tp => tp })
2895+
// we're looking for a method (as indicated by FUNmode), so let's make sure our expected type is a MethodType
2896+
val methPt = MethodType(methArgs, respt)
2897+
2898+
silent(_.typed(meth, mode.forFunMode, methPt)) filter (_ => context.undetparams.isEmpty) map { methTyped =>
2899+
// if context.undetparams is not empty, the method was polymorphic,
2900+
// so we need the missing arguments to infer its type. See #871
2901+
val funPt = normalize(methTyped.tpe) baseType FunctionClass(numVparams)
2902+
// println(s"typeUnEtaExpanded $meth : ${methTyped.tpe} --> normalized: $funPt")
2903+
2904+
if (isFunctionType(funPt) && isFullyDefined(funPt)) funPt
2905+
else null
2906+
} orElse { _ => null }
2907+
case _ => null
2908+
} else null
2909+
29092910

2910-
if (expectedFunTypeBeforeEtaExpansion ne NoType) typedFunction(fun, mode, expectedFunTypeBeforeEtaExpansion)
2911+
if (ptUnrollingEtaExpansion ne null) typedFunction(fun, mode, ptUnrollingEtaExpansion)
29112912
else {
29122913
// we ran out of things to try, missing parameter types are an irrevocable error
29132914
var issuedMissingParameterTypeError = false
@@ -2925,24 +2926,24 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
29252926
// thus, its symbol, which serves as the current context.owner, is not the right owner
29262927
// you won't know you're using the wrong owner until lambda lift crashes (unless you know better than to use the wrong owner)
29272928
val outerTyper = newTyper(context.outer)
2928-
val p = fun.vparams.head
2929+
val p = vparams.head
29292930
if (p.tpt.tpe == null) p.tpt setType outerTyper.typedType(p.tpt).tpe
29302931

29312932
outerTyper.synthesizePartialFunction(p.name, p.pos, paramSynthetic = false, fun.body, mode, pt)
29322933

29332934
case _ =>
2934-
val vparamSyms = fun.vparams map { vparam =>
2935+
val vparamSyms = vparams map { vparam =>
29352936
enterSym(context, vparam)
29362937
if (context.retyping) context.scope enter vparam.symbol
29372938
vparam.symbol
29382939
}
2939-
val vparams = fun.vparams mapConserve typedValDef
2940+
val vparamsTyped = vparams mapConserve typedValDef
29402941
val formals = vparamSyms map (_.tpe)
29412942
val body1 = typed(fun.body, respt)
29422943
val restpe = packedType(body1, fun.symbol).deconst.resultType
29432944
val funtpe = phasedAppliedType(FunctionSymbol, formals :+ restpe)
29442945

2945-
treeCopy.Function(fun, vparams, body1) setType funtpe
2946+
treeCopy.Function(fun, vparamsTyped, body1) setType funtpe
29462947
}
29472948
}
29482949
}

test/files/pos/fun_undo_eta.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
class Test {
2+
def m(i: Int) = i
3+
4+
def expectWild[A](f: A) = ???
5+
def expectFun[A](f: A => Int) = ???
6+
7+
expectWild((i => m(i))) // manual eta expansion
8+
expectWild(m(_)) // have to undo eta expansion with wildcard expected type
9+
expectFun(m(_)) // have to undo eta expansion with function expected type
10+
}

0 commit comments

Comments
 (0)