Skip to content

Commit 08ea73a

Browse files
authored
Merge pull request scala#6003 from szeiger/fix/pf-overloads2
Two fixes for overload resolution for PartialFunctions
2 parents 64cbe13 + 56040a3 commit 08ea73a

File tree

4 files changed

+46
-8
lines changed

4 files changed

+46
-8
lines changed

src/compiler/scala/tools/nsc/typechecker/Typers.scala

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3341,7 +3341,7 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
33413341

33423342
def funArgTypes(tpAlts: List[(Type, Symbol)]) = tpAlts.map { case (tp, alt) =>
33433343
val relTp = tp.asSeenFrom(pre, alt.owner)
3344-
val argTps = functionOrSamArgTypes(relTp)
3344+
val argTps = functionOrPfOrSamArgTypes(relTp)
33453345
//println(s"funArgTypes $argTps from $relTp")
33463346
argTps.map(approximateAbstracts)
33473347
}
@@ -3350,6 +3350,10 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
33503350
try functionType(funArgTypes(argTpWithAlt).transpose.map(lub), WildcardType)
33513351
catch { case _: IllegalArgumentException => WildcardType }
33523352

3353+
def partialFunctionProto(argTpWithAlt: List[(Type, Symbol)]): Type =
3354+
try appliedType(PartialFunctionClass, funArgTypes(argTpWithAlt).transpose.map(lub) :+ WildcardType)
3355+
catch { case _: IllegalArgumentException => WildcardType }
3356+
33533357
// To propagate as much information as possible to typedFunction, which uses the expected type to
33543358
// infer missing parameter types for Function trees that we're typing as arguments here,
33553359
// we expand the parameter types for all alternatives to the expected argument length,
@@ -3360,7 +3364,7 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
33603364
// and lubbing the argument types (we treat SAM and FunctionN types equally, but non-function arguments
33613365
// do not receive special treatment: they are typed under WildcardType.)
33623366
val altArgPts =
3363-
if (settings.isScala212 && args.exists(treeInfo.isFunctionMissingParamType))
3367+
if (settings.isScala212 && args.exists(t => treeInfo.isFunctionMissingParamType(t) || treeInfo.isPartialFunctionMissingParamType(t)))
33643368
try alts.map(alt => formalTypes(alt.info.paramTypes, argslen).map(ft => (ft, alt))).transpose // do least amount of work up front
33653369
catch { case _: IllegalArgumentException => args.map(_ => Nil) } // fail safe in case formalTypes fails to align to argslen
33663370
else args.map(_ => Nil) // will type under argPt == WildcardType
@@ -3375,8 +3379,12 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
33753379
// the overloaded type into a single function type from which `typedFunction`
33763380
// can derive the argument type for `x` in the function literal above
33773381
val argPt =
3378-
if (argPtAlts.nonEmpty && treeInfo.isFunctionMissingParamType(tree)) functionProto(argPtAlts)
3379-
else WildcardType
3382+
if (argPtAlts.isEmpty) WildcardType
3383+
else if (treeInfo.isFunctionMissingParamType(tree)) functionProto(argPtAlts)
3384+
else if (treeInfo.isPartialFunctionMissingParamType(tree)) {
3385+
if (argPtAlts.exists(ts => isPartialFunctionType(ts._1))) partialFunctionProto(argPtAlts)
3386+
else functionProto(argPtAlts)
3387+
} else WildcardType
33803388

33813389
val argTyped = typedArg(tree, amode, BYVALmode, argPt)
33823390
(argTyped, argTyped.tpe.deconst)

src/reflect/scala/reflect/internal/Definitions.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -691,11 +691,11 @@ trait Definitions extends api.StandardDefinitions {
691691
}
692692
}
693693

694-
// the argument types expected by the function described by `tp` (a FunctionN or SAM type),
695-
// or `Nil` if `tp` does not represent a function type or SAM (or if it happens to be Function0...)
696-
def functionOrSamArgTypes(tp: Type): List[Type] = {
694+
// the argument types expected by the function described by `tp` (a FunctionN or PartialFunction or SAM type),
695+
// or `Nil` if `tp` does not represent a function type or PartialFunction or SAM (or if it happens to be Function0...)
696+
def functionOrPfOrSamArgTypes(tp: Type): List[Type] = {
697697
val dealiased = tp.dealiasWiden
698-
if (isFunctionTypeDirect(dealiased)) dealiased.typeArgs.init
698+
if (isFunctionTypeDirect(dealiased) || isPartialFunctionType(dealiased)) dealiased.typeArgs.init
699699
else samOf(tp) match {
700700
case samSym if samSym.exists => tp.memberInfo(samSym).paramTypes
701701
case _ => Nil

src/reflect/scala/reflect/internal/TreeInfo.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,10 @@ abstract class TreeInfo {
265265

266266
def isFunctionMissingParamType(tree: Tree): Boolean = tree match {
267267
case Function(vparams, _) => vparams.exists(_.tpt.isEmpty)
268+
case _ => false
269+
}
270+
271+
def isPartialFunctionMissingParamType(tree: Tree): Boolean = tree match {
268272
case Match(EmptyTree, _) => true
269273
case _ => false
270274
}

test/files/run/InferOverloadedPartialFunction.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,30 @@ object Test extends App {
4444
def h[R](pf: Function2[Int, String, R]): Int = 1
4545
def h[R](pf: PartialFunction[(Double, Double), R]): Int = 2
4646
assert(h { case (a: Double, b: Double) => 42: Int } == 2)
47+
48+
val xs = new SortedMap
49+
assert(xs.collectF { kv => 1 } == 0)
50+
assert(xs.collectF { case (k, v) => 1 } == 0)
51+
assert(xs.collectF { case (k, v) => (1, 1) } == 2)
52+
assert(xs.collect { case (k, v) => 1 } == 0)
53+
assert(xs.collect { case (k, v) => (1, 1) } == 1)
54+
55+
val ys = new SortedMapMixed
56+
assert(ys.collect { kv => 1 } == 0)
57+
assert(ys.collect { kv => (1, 1) } == 0)
58+
assert(ys.collect { case (k, v) => 1 } == 1) // could be 0 with the extra work in https://github.com/scala/scala/pull/5975/commits/3c95dac0dcbb0c8eb4686264026ad9c86b2022de
59+
assert(ys.collect { case (k, v) => (1, 1) } == 2)
60+
}
61+
62+
class SortedMap {
63+
def collect[B](pf: PartialFunction[(String, Int), B]): Int = 0
64+
def collect[K2 : Ordering, V2](pf: PartialFunction[(String, Int), (K2, V2)]): Int = 1
65+
def collectF[B](pf: Function1[(String, Int), B]): Int = if(pf.isInstanceOf[PartialFunction[_, _]]) 1 else 0
66+
def collectF[K2 : Ordering, V2](pf: Function1[(String, Int), (K2, V2)]): Int = if(pf.isInstanceOf[PartialFunction[_, _]]) 3 else 2
67+
}
68+
69+
class SortedMapMixed {
70+
type PF[-A, +B] = PartialFunction[A, B]
71+
def collect[B](pf: Function1[(String, Int), B]): Int = if(pf.isInstanceOf[PartialFunction[_, _]]) 1 else 0
72+
def collect[K2 : Ordering, V2](pf: PF[(String, Int), (K2, V2)]): Int = 2
4773
}

0 commit comments

Comments
 (0)