diff --git a/src/library/scala/concurrent/Future.scala b/src/library/scala/concurrent/Future.scala index def086bc038d..71b6fe71d2a6 100644 --- a/src/library/scala/concurrent/Future.scala +++ b/src/library/scala/concurrent/Future.scala @@ -82,9 +82,29 @@ import language.higherKinds * {{{ * f flatMap { (x: Int) => g map { (y: Int) => x + y } } * }}} + * + * @define callbackInContext + * The provided callback always runs in the provided implicit + *`ExecutionContext`, though there is no guarantee that the + * `execute()` method on the `ExecutionContext` will be called once + * per callback or that `execute()` will be called in the current + * thread. That is, the implementation may run multiple callbacks + * in a batch within a single `execute()` and it may run + * `execute()` either immediately or asynchronously. */ trait Future[+T] extends Awaitable[T] { + // The executor within the lexical scope + // of the Future trait. Note that this will + // (modulo bugs) _never_ execute a callback + // other than those below in this same file. + // As a nice side benefit, having this implicit + // here forces an ambiguity in those methods + // that also have an executor parameter, which + // keeps us from accidentally forgetting to use + // the executor parameter. + private implicit def internalExecutor: ExecutionContext = Future.InternalCallbackExecutor + /* Callbacks */ /** When this future is completed successfully (i.e. with a value), @@ -95,11 +115,12 @@ trait Future[+T] extends Awaitable[T] { * this will either be applied immediately or be scheduled asynchronously. * * $multipleCallbacks + * $callbackInContext */ - def onSuccess[U](pf: PartialFunction[T, U]): this.type = onComplete { + def onSuccess[U](pf: PartialFunction[T, U])(implicit executor: ExecutionContext): this.type = onComplete { case Left(t) => // do nothing case Right(v) => if (pf isDefinedAt v) pf(v) else { /*do nothing*/ } - } + } (executor) /** When this future is completed with a failure (i.e. with a throwable), * apply the provided callback to the throwable. @@ -112,11 +133,12 @@ trait Future[+T] extends Awaitable[T] { * Will not be called in case that the future is completed with a value. * * $multipleCallbacks + * $callbackInContext */ - def onFailure[U](callback: PartialFunction[Throwable, U]): this.type = onComplete { + def onFailure[U](callback: PartialFunction[Throwable, U])(implicit executor: ExecutionContext): this.type = onComplete { case Left(t) => if (isFutureThrowable(t) && callback.isDefinedAt(t)) callback(t) else { /*do nothing*/ } case Right(v) => // do nothing - } + } (executor) /** When this future is completed, either through an exception, or a value, * apply the provided function. @@ -125,16 +147,12 @@ trait Future[+T] extends Awaitable[T] { * this will either be applied immediately or be scheduled asynchronously. * * $multipleCallbacks + * $callbackInContext */ - def onComplete[U](func: Either[Throwable, T] => U): this.type - + def onComplete[U](func: Either[Throwable, T] => U)(implicit executor: ExecutionContext): this.type /* Miscellaneous */ - /** Creates a new promise. - */ - protected def newPromise[S]: Promise[S] - /** Returns whether the future has already been completed with * a value or an exception. * @@ -169,7 +187,7 @@ trait Future[+T] extends Awaitable[T] { * and throws a corresponding exception if the original future fails. */ def failed: Future[Throwable] = { - val p = newPromise[Throwable] + val p = Promise[Throwable] onComplete { case Left(t) => p success t @@ -186,10 +204,10 @@ trait Future[+T] extends Awaitable[T] { * * Will not be called if the future fails. */ - def foreach[U](f: T => U): Unit = onComplete { + def foreach[U](f: T => U)(implicit executor: ExecutionContext): Unit = onComplete { case Right(r) => f(r) case Left(_) => // do nothing - } + } (executor) /** Creates a new future by applying a function to the successful result of * this future. If this future is completed with an exception then the new @@ -197,8 +215,8 @@ trait Future[+T] extends Awaitable[T] { * * $forComprehensionExample */ - def map[S](f: T => S): Future[S] = { - val p = newPromise[S] + def map[S](f: T => S)(implicit executor: ExecutionContext): Future[S] = { + val p = Promise[S] onComplete { case Left(t) => p failure t @@ -207,7 +225,7 @@ trait Future[+T] extends Awaitable[T] { catch { case NonFatal(t) => p complete resolver(t) } - } + } (executor) p.future } @@ -219,21 +237,21 @@ trait Future[+T] extends Awaitable[T] { * * $forComprehensionExample */ - def flatMap[S](f: T => Future[S]): Future[S] = { - val p = newPromise[S] + def flatMap[S](f: T => Future[S])(implicit executor: ExecutionContext): Future[S] = { + val p = Promise[S] onComplete { case Left(t) => p failure t case Right(v) => try { - f(v) onComplete { + f(v).onComplete({ case Left(t) => p failure t case Right(v) => p success v - } + })(internalExecutor) } catch { case NonFatal(t) => p complete resolver(t) } - } + } (executor) p.future } @@ -254,8 +272,8 @@ trait Future[+T] extends Awaitable[T] { * await(h, 0) // throw a NoSuchElementException * }}} */ - def filter(pred: T => Boolean): Future[T] = { - val p = newPromise[T] + def filter(pred: T => Boolean)(implicit executor: ExecutionContext): Future[T] = { + val p = Promise[T] onComplete { case Left(t) => p failure t @@ -266,14 +284,14 @@ trait Future[+T] extends Awaitable[T] { } catch { case NonFatal(t) => p complete resolver(t) } - } + } (executor) p.future } /** Used by for-comprehensions. */ - final def withFilter(p: T => Boolean): Future[T] = filter(p) + final def withFilter(p: T => Boolean)(implicit executor: ExecutionContext): Future[T] = filter(p)(executor) // final def withFilter(p: T => Boolean) = new FutureWithFilter[T](this, p) // final class FutureWithFilter[+S](self: Future[S], p: S => Boolean) { @@ -303,8 +321,8 @@ trait Future[+T] extends Awaitable[T] { * await(h, 0) // throw a NoSuchElementException * }}} */ - def collect[S](pf: PartialFunction[T, S]): Future[S] = { - val p = newPromise[S] + def collect[S](pf: PartialFunction[T, S])(implicit executor: ExecutionContext): Future[S] = { + val p = Promise[S] onComplete { case Left(t) => p failure t @@ -315,7 +333,7 @@ trait Future[+T] extends Awaitable[T] { } catch { case NonFatal(t) => p complete resolver(t) } - } + } (executor) p.future } @@ -332,15 +350,15 @@ trait Future[+T] extends Awaitable[T] { * future (6 / 2) recover { case e: ArithmeticException ⇒ 0 } // result: 3 * }}} */ - def recover[U >: T](pf: PartialFunction[Throwable, U]): Future[U] = { - val p = newPromise[U] + def recover[U >: T](pf: PartialFunction[Throwable, U])(implicit executor: ExecutionContext): Future[U] = { + val p = Promise[U] onComplete { case Left(t) if pf isDefinedAt t => try { p success pf(t) } catch { case NonFatal(t) => p complete resolver(t) } case otherwise => p complete otherwise - } + } (executor) p.future } @@ -358,8 +376,8 @@ trait Future[+T] extends Awaitable[T] { * future (6 / 0) recoverWith { case e: ArithmeticException => f } // result: Int.MaxValue * }}} */ - def recoverWith[U >: T](pf: PartialFunction[Throwable, Future[U]]): Future[U] = { - val p = newPromise[U] + def recoverWith[U >: T](pf: PartialFunction[Throwable, Future[U]])(implicit executor: ExecutionContext): Future[U] = { + val p = Promise[U] onComplete { case Left(t) if pf isDefinedAt t => @@ -369,7 +387,7 @@ trait Future[+T] extends Awaitable[T] { case NonFatal(t) => p complete resolver(t) } case otherwise => p complete otherwise - } + } (executor) p.future } @@ -383,16 +401,16 @@ trait Future[+T] extends Awaitable[T] { * with the throwable stored in `that`. */ def zip[U](that: Future[U]): Future[(T, U)] = { - val p = newPromise[(T, U)] + val p = Promise[(T, U)] - this onComplete { + this.onComplete { case Left(t) => p failure t - case Right(r) => that onSuccess { + case Right(r) => that.onSuccess { case r2 => p success ((r, r2)) } } - that onFailure { + that.onFailure { case f => p failure f } @@ -414,7 +432,7 @@ trait Future[+T] extends Awaitable[T] { * }}} */ def fallbackTo[U >: T](that: Future[U]): Future[U] = { - val p = newPromise[U] + val p = Promise[U] onComplete { case r @ Right(_) ⇒ p complete r case _ ⇒ p completeWith that @@ -443,7 +461,7 @@ trait Future[+T] extends Awaitable[T] { if (c.isPrimitive) toBoxed(c) else c } - val p = newPromise[S] + val p = Promise[S] onComplete { case l: Left[Throwable, _] => p complete l.asInstanceOf[Either[Throwable, S]] @@ -481,14 +499,14 @@ trait Future[+T] extends Awaitable[T] { * } * }}} */ - def andThen[U](pf: PartialFunction[Either[Throwable, T], U]): Future[T] = { - val p = newPromise[T] + def andThen[U](pf: PartialFunction[Either[Throwable, T], U])(implicit executor: ExecutionContext): Future[T] = { + val p = Promise[T] onComplete { case r => try if (pf isDefinedAt r) pf(r) finally p complete r - } + } (executor) p.future } @@ -507,7 +525,7 @@ trait Future[+T] extends Awaitable[T] { * }}} */ def either[U >: T](that: Future[U]): Future[U] = { - val p = newPromise[U] + val p = Promise[U] val completePromise: PartialFunction[Either[Throwable, U], _] = { case Left(t) => p tryFailure t @@ -536,10 +554,10 @@ object Future { * * @tparam T the type of the result * @param body the asychronous computation - * @param execctx the execution context on which the future is run + * @param executor the execution context which runs the body * @return the `Future` holding the result of the computation */ - def apply[T](body: =>T)(implicit execctx: ExecutionContext): Future[T] = impl.Future(body) + def apply[T](body: =>T)(implicit executor: ExecutionContext): Future[T] = impl.Future(body) import scala.collection.mutable.Builder import scala.collection.generic.CanBuildFrom @@ -628,6 +646,33 @@ object Future { for (r <- fr; b <- fb) yield (r += b) }.map(_.result) + // This is used to run callbacks which are internal + // to scala.concurrent; our own callbacks are only + // ever used to eventually run another callback, + // and that other callback will have its own + // executor because all callbacks come with + // an executor. Our own callbacks never block + // and have no "expected" exceptions. + // As a result, this executor can do nothing; + // some other executor will always come after + // it (and sometimes one will be before it), + // and those will be performing the "real" + // dispatch to code outside scala.concurrent. + // Because this exists, ExecutionContext.defaultExecutionContext + // isn't instantiated by Future internals, so + // if some code for some reason wants to avoid + // ever starting up the default context, it can do so + // by just not ever using it itself. scala.concurrent + // doesn't need to create defaultExecutionContext as + // a side effect. + private[concurrent] object InternalCallbackExecutor extends ExecutionContext { + def execute(runnable: Runnable): Unit = + runnable.run() + def internalBlockingCall[T](awaitable: Awaitable[T], atMost: Duration): T = + throw new IllegalStateException("bug in scala.concurrent, called blocking() from internal callback") + def reportFailure(t: Throwable): Unit = + throw new IllegalStateException("problem in scala.concurrent internal callback", t) + } } diff --git a/src/library/scala/concurrent/Promise.scala b/src/library/scala/concurrent/Promise.scala index f7ec0714cfb0..c1c5d30b552b 100644 --- a/src/library/scala/concurrent/Promise.scala +++ b/src/library/scala/concurrent/Promise.scala @@ -25,6 +25,11 @@ package scala.concurrent */ trait Promise[T] { + // used for internal callbacks defined in + // the lexical scope of this trait; + // _never_ for application callbacks. + private implicit def internalExecutor: ExecutionContext = Future.InternalCallbackExecutor + /** Future containing the value of this promise. */ def future: Future[T] @@ -106,26 +111,23 @@ object Promise { /** Creates a promise object which can be completed with a value. * * @tparam T the type of the value in the promise - * @param execctx the execution context on which the promise is created on * @return the newly created `Promise` object */ - def apply[T]()(implicit executor: ExecutionContext): Promise[T] = new impl.Promise.DefaultPromise[T]() + def apply[T](): Promise[T] = new impl.Promise.DefaultPromise[T]() /** Creates an already completed Promise with the specified exception. * * @tparam T the type of the value in the promise - * @param execctx the execution context on which the promise is created on * @return the newly created `Promise` object */ - def failed[T](exception: Throwable)(implicit executor: ExecutionContext): Promise[T] = new impl.Promise.KeptPromise[T](Left(exception)) + def failed[T](exception: Throwable): Promise[T] = new impl.Promise.KeptPromise[T](Left(exception)) /** Creates an already completed Promise with the specified result. * * @tparam T the type of the value in the promise - * @param execctx the execution context on which the promise is created on * @return the newly created `Promise` object */ - def successful[T](result: T)(implicit executor: ExecutionContext): Promise[T] = new impl.Promise.KeptPromise[T](Right(result)) + def successful[T](result: T): Promise[T] = new impl.Promise.KeptPromise[T](Right(result)) } diff --git a/src/library/scala/concurrent/impl/Future.scala b/src/library/scala/concurrent/impl/Future.scala index 20d4122e8f90..957f9cf1ddc4 100644 --- a/src/library/scala/concurrent/impl/Future.scala +++ b/src/library/scala/concurrent/impl/Future.scala @@ -17,8 +17,6 @@ import scala.collection.mutable.Stack private[concurrent] trait Future[+T] extends scala.concurrent.Future[T] with Awaitable[T] { - implicit def executor: ExecutionContext - } private[concurrent] object Future { diff --git a/src/library/scala/concurrent/impl/Promise.scala b/src/library/scala/concurrent/impl/Promise.scala index 54be848f14f8..0beb1f5309b0 100644 --- a/src/library/scala/concurrent/impl/Promise.scala +++ b/src/library/scala/concurrent/impl/Promise.scala @@ -29,11 +29,9 @@ private[concurrent] trait Promise[T] extends scala.concurrent.Promise[T] with Fu object Promise { /** Default promise implementation. */ - class DefaultPromise[T](implicit val executor: ExecutionContext) extends AbstractPromise with Promise[T] { self => + class DefaultPromise[T] extends AbstractPromise with Promise[T] { self => updater.set(this, Nil) // Start at "No callbacks" //FIXME switch to Unsafe instead of ARFU - def newPromise[S]: scala.concurrent.Promise[S] = new Promise.DefaultPromise() - protected final def tryAwait(atMost: Duration): Boolean = { @tailrec def awaitUnsafe(waitTimeNanos: Long): Boolean = { @@ -106,22 +104,30 @@ object Promise { }) match { case null => false case cs if cs.isEmpty => true - case cs => Future.dispatchFuture(executor, () => cs.foreach(f => notifyCompleted(f, resolved))); true + // this assumes that bindDispatch() was called to create f, + // so it will go via dispatchFuture and notifyCompleted + case cs => cs.foreach(f => f(resolved)); true } } - def onComplete[U](func: Either[Throwable, T] => U): this.type = { + private def bindDispatch(func: Either[Throwable, T] => Any)(implicit executor: ExecutionContext): Either[Throwable, T] => Unit = { + either: Either[Throwable, T] => + Future.dispatchFuture(executor, () => notifyCompleted(func, either)) + } + + def onComplete[U](func: Either[Throwable, T] => U)(implicit executor: ExecutionContext): this.type = { + val bound = bindDispatch(func) @tailrec //Tries to add the callback, if already completed, it dispatches the callback to be executed def dispatchOrAddCallback(): Unit = getState match { - case r: Either[_, _] => Future.dispatchFuture(executor, () => notifyCompleted(func, r.asInstanceOf[Either[Throwable, T]])) - case listeners: List[_] => if (updateState(listeners, func :: listeners)) () else dispatchOrAddCallback() + case r: Either[_, _] => bound(r.asInstanceOf[Either[Throwable, T]]) + case listeners: List[_] => if (updateState(listeners, bound :: listeners)) () else dispatchOrAddCallback() } dispatchOrAddCallback() this } - private final def notifyCompleted(func: Either[Throwable, T] => Any, result: Either[Throwable, T]) { + private final def notifyCompleted(func: Either[Throwable, T] => Any, result: Either[Throwable, T])(implicit executor: ExecutionContext) { try { func(result) } catch { @@ -134,17 +140,15 @@ object Promise { * * Useful in Future-composition when a value to contribute is already available. */ - final class KeptPromise[T](suppliedValue: Either[Throwable, T])(implicit val executor: ExecutionContext) extends Promise[T] { + final class KeptPromise[T](suppliedValue: Either[Throwable, T]) extends Promise[T] { val value = Some(resolveEither(suppliedValue)) override def isCompleted(): Boolean = true - def newPromise[S]: scala.concurrent.Promise[S] = new Promise.DefaultPromise() - def tryComplete(value: Either[Throwable, T]): Boolean = false - def onComplete[U](func: Either[Throwable, T] => U): this.type = { + def onComplete[U](func: Either[Throwable, T] => U)(implicit executor: ExecutionContext): this.type = { val completedAs = value.get Future.dispatchFuture(executor, () => func(completedAs)) this diff --git a/test/files/jvm/scala-concurrent-tck.scala b/test/files/jvm/scala-concurrent-tck.scala index fce1a25bb616..37f1d21e0526 100644 --- a/test/files/jvm/scala-concurrent-tck.scala +++ b/test/files/jvm/scala-concurrent-tck.scala @@ -732,6 +732,126 @@ trait TryEitherExtractor extends TestBase { testLeftMatch() } +trait CustomExecutionContext extends TestBase { + import scala.concurrent.{ ExecutionContext, Awaitable } + + def defaultEC = ExecutionContext.defaultExecutionContext + + val inEC = new java.lang.ThreadLocal[Int]() { + override def initialValue = 0 + } + + def enterEC() = inEC.set(inEC.get + 1) + def leaveEC() = inEC.set(inEC.get - 1) + def assertEC() = assert(inEC.get > 0) + def assertNoEC() = assert(inEC.get == 0) + + class CountingExecutionContext extends ExecutionContext { + val _count = new java.util.concurrent.atomic.AtomicInteger(0) + def count = _count.get + + def delegate = ExecutionContext.defaultExecutionContext + + override def execute(runnable: Runnable) = { + _count.incrementAndGet() + val wrapper = new Runnable() { + override def run() = { + enterEC() + try { + runnable.run() + } finally { + leaveEC() + } + } + } + delegate.execute(wrapper) + } + + override def internalBlockingCall[T](awaitable: Awaitable[T], atMost: Duration): T = + delegate.internalBlockingCall(awaitable, atMost) + + override def reportFailure(t: Throwable): Unit = { + System.err.println("Failure: " + t.getClass.getSimpleName + ": " + t.getMessage) + delegate.reportFailure(t) + } + } + + def countExecs(block: (ExecutionContext) => Unit): Int = { + val context = new CountingExecutionContext() + block(context) + context.count + } + + def testOnSuccessCustomEC(): Unit = { + val count = countExecs { implicit ec => + once { done => + val f = future({ assertNoEC() })(defaultEC) + f onSuccess { + case _ => + assertEC() + done() + } + assertNoEC() + } + } + + // should be onSuccess, but not future body + assert(count == 1) + } + + def testKeptPromiseCustomEC(): Unit = { + val count = countExecs { implicit ec => + once { done => + val f = Promise.successful(10).future + f onSuccess { + case _ => + assertEC() + done() + } + } + } + + // should be onSuccess called once in proper EC + assert(count == 1) + } + + def testCallbackChainCustomEC(): Unit = { + val count = countExecs { implicit ec => + once { done => + assertNoEC() + val addOne = { x: Int => assertEC(); x + 1 } + val f = Promise.successful(10).future + f.map(addOne).filter { x => + assertEC() + x == 11 + } flatMap { x => + Promise.successful(x + 1).future.map(addOne).map(addOne) + } onComplete { + case Left(t) => + try { + throw new AssertionError("error in test: " + t.getMessage, t) + } finally { + done() + } + case Right(x) => + assertEC() + assert(x == 14) + done() + } + assertNoEC() + } + } + + // the count is not defined (other than >=1) + // due to the batching optimizations. + assert(count >= 1) + } + + testOnSuccessCustomEC() + testKeptPromiseCustomEC() + testCallbackChainCustomEC() +} + object Test extends App with FutureCallbacks @@ -740,6 +860,7 @@ with FutureProjections with Promises with Exceptions with TryEitherExtractor +with CustomExecutionContext { System.exit(0) }