diff --git a/include/stdexec/__detail/__affine_on.hpp b/include/stdexec/__detail/__affine_on.hpp index b70ccbf42..3254f9fd0 100644 --- a/include/stdexec/__detail/__affine_on.hpp +++ b/include/stdexec/__detail/__affine_on.hpp @@ -51,9 +51,9 @@ namespace STDEXEC struct affine_on_t { template - constexpr auto operator()(_Sender&& __sndr) const -> __well_formed_sender auto + constexpr auto operator()(_Sender &&__sndr) const -> __well_formed_sender auto { - return __make_sexpr({}, static_cast<_Sender&&>(__sndr)); + return __make_sexpr({}, static_cast<_Sender &&>(__sndr)); } constexpr auto operator()() const noexcept @@ -62,10 +62,10 @@ namespace STDEXEC } template - static constexpr auto transform_sender(set_value_t, _Sender&& __sndr, _Env const & __env) + static constexpr auto transform_sender(set_value_t, _Sender &&__sndr, _Env const &__env) { static_assert(sender_expr_for<_Sender, affine_on_t>); - auto& [__tag, __ign, __child] = __sndr; + auto &[__tag, __ign, __child] = __sndr; using __child_t = decltype(__child); using __cv_child_t = __copy_cvref_t<_Sender, __child_t>; using __sched_t = __call_result_or_t, _Env const &>; @@ -116,14 +116,26 @@ namespace STDEXEC namespace __affine_on { + template struct __attrs { - template - constexpr auto query(__get_completion_behavior_t<_Tag>) const noexcept + template + requires __queryable_with<_Attrs, __get_completion_behavior_t<_Tag>, _Env const &...> + constexpr auto query(__get_completion_behavior_t<_Tag>, _Env const &...) const noexcept { - // FUTURE: when the child sender completes inline *and* the current scheduler also - // completes inline, we can return "inline" here instead of "__asynchronous_affine". - return __completion_behavior::__asynchronous_affine; + using __behavior_t = + __query_result_t<_Attrs, __get_completion_behavior_t<_Tag>, _Env const &...>; + + // When the child sender completes inline, we can return "inline" here instead of + // "__asynchronous_affine". + if constexpr (__behavior_t::value == __completion_behavior::__inline_completion) + { + return __completion_behavior::__inline_completion; + } + else + { + return __completion_behavior::__asynchronous_affine; + } } }; } // namespace __affine_on @@ -132,9 +144,9 @@ namespace STDEXEC struct __sexpr_impl : __sexpr_defaults { static constexpr auto __get_attrs = // - [](__ignore, __ignore, __ignore) noexcept + [](__ignore, __ignore, _Child const &) noexcept { - return __affine_on::__attrs{}; + return __affine_on::__attrs>{}; }; }; } // namespace STDEXEC diff --git a/include/stdexec/__detail/__as_awaitable.hpp b/include/stdexec/__detail/__as_awaitable.hpp index 8a075381a..2270d39c4 100644 --- a/include/stdexec/__detail/__as_awaitable.hpp +++ b/include/stdexec/__detail/__as_awaitable.hpp @@ -84,13 +84,14 @@ namespace STDEXEC using __expected_t = std::variant, std::exception_ptr>; - // Helper to cast a coroutine_handle to coroutine_handle<_Promise> - template - constexpr auto __coroutine_handle_cast(__std::coroutine_handle<> __hcoro) noexcept - -> __std::coroutine_handle<_Promise> - { - return __std::coroutine_handle<_Promise>::from_address(__hcoro.address()); - } + template + concept __completes_inline_for = __never_sends<_Tag, _Sender, _Env...> + || STDEXEC::__completes_inline<_Tag, env_of_t<_Sender>, _Env...>; + + template + concept __completes_inline = __completes_inline_for + && __completes_inline_for + && __completes_inline_for; template struct __receiver_base @@ -98,17 +99,15 @@ namespace STDEXEC using receiver_concept = receiver_t; template - requires __std::constructible_from<__value_or_void_t<_Value>, _Us...> void set_value(_Us&&... __us) noexcept { STDEXEC_TRY { - __result_->template emplace<1>(static_cast<_Us&&>(__us)...); - __continuation_.resume(); + __result_.template emplace<1>(static_cast<_Us&&>(__us)...); } STDEXEC_CATCH_ALL { - STDEXEC::set_error(static_cast<__receiver_base&&>(*this), std::current_exception()); + __result_.template emplace<2>(std::current_exception()); } } @@ -116,91 +115,183 @@ namespace STDEXEC void set_error(_Error&& __err) noexcept { if constexpr (__decays_to<_Error, std::exception_ptr>) - __result_->template emplace<2>(static_cast<_Error&&>(__err)); + __result_.template emplace<2>(static_cast<_Error&&>(__err)); else if constexpr (__decays_to<_Error, std::error_code>) - __result_->template emplace<2>(std::make_exception_ptr(std::system_error(__err))); + __result_.template emplace<2>(std::make_exception_ptr(std::system_error(__err))); else - __result_->template emplace<2>(std::make_exception_ptr(static_cast<_Error&&>(__err))); - __continuation_.resume(); + __result_.template emplace<2>(std::make_exception_ptr(static_cast<_Error&&>(__err))); } - __expected_t<_Value>* __result_; - __std::coroutine_handle<> __continuation_; + __expected_t<_Value>& __result_; }; template - struct __receiver : __receiver_base<_Value> + struct __sync_receiver : __receiver_base<_Value> { - constexpr void set_stopped() noexcept + constexpr explicit __sync_receiver(__expected_t<_Value>& __result, + __std::coroutine_handle<_Promise> __continuation) noexcept + : __receiver_base<_Value>{__result} + , __continuation_{__continuation} + {} + + void set_stopped() noexcept { - auto __continuation = __coroutine_handle_cast<_Promise>(this->__continuation_); - // Do not use type deduction here so that we perform any conversions necessary on - // the stopped continuation: - __std::coroutine_handle<> __on_stopped = __continuation.promise().unhandled_stopped(); - __on_stopped.resume(); + // no-op: the __result_ variant will remain engaged with the monostate + // alternative, which signals that the operation was stopped. } // Forward get_env query to the coroutine promise constexpr auto get_env() const noexcept -> env_of_t<_Promise&> { - auto const __continuation = __coroutine_handle_cast<_Promise>(this->__continuation_); - return STDEXEC::get_env(__continuation.promise()); + return STDEXEC::get_env(__continuation_.promise()); + } + + __std::coroutine_handle<_Promise> __continuation_; + }; + + // The receiver type used to connect to senders that could complete asynchronously. + template + struct __async_receiver : __sync_receiver<_Promise, _Value> + { + constexpr explicit __async_receiver(__expected_t<_Value>& __result, + __std::coroutine_handle<_Promise> __continuation) noexcept + : __sync_receiver<_Promise, _Value>{__result, __continuation} + {} + + template + void set_value(_Us&&... __us) noexcept + { + this->__sync_receiver<_Promise, _Value>::set_value(static_cast<_Us&&>(__us)...); + this->__continuation_.resume(); + } + + template + void set_error(_Error&& __err) noexcept + { + this->__sync_receiver<_Promise, _Value>::set_error(static_cast<_Error&&>(__err)); + this->__continuation_.resume(); + } + + constexpr void set_stopped() noexcept + { + STDEXEC_TRY + { + // Resuming the stopped continuation unwinds the coroutine stack until we reach + // a promise that can handle the stopped signal. The coroutine referred to by + // __continuation_ will never be resumed. + __std::coroutine_handle<> __on_stopped = + this->__continuation_.promise().unhandled_stopped(); + __on_stopped.resume(); + } + STDEXEC_CATCH_ALL + { + this->__result_.template emplace<2>(std::current_exception()); + this->__continuation_.resume(); + } } }; template - using __receiver_t = __receiver<_Promise, __detail::__value_t<_Sender, _Promise>>; + using __sync_receiver_t = __sync_receiver<_Promise, __detail::__value_t<_Sender, _Promise>>; + + template + using __async_receiver_t = __async_receiver<_Promise, __detail::__value_t<_Sender, _Promise>>; template struct __sender_awaitable_base { - [[nodiscard]] - constexpr auto await_ready() const noexcept -> bool + static constexpr auto await_ready() noexcept -> bool { return false; } constexpr auto await_resume() -> _Value { - switch (__result_.index()) + // If the operation completed with set_stopped (as denoted by the monostate + // alternative being active), we should not be resuming this coroutine at all. + STDEXEC_ASSERT(__result_.index() != 0); + if (__result_.index() == 2) { - case 0: // receiver contract not satisfied - STDEXEC_ASSERT(false && +"_Should never get here" == nullptr); - break; - case 1: // set_value - if constexpr (!__same_as<_Value, void>) - return static_cast<_Value&&>(std::get<1>(__result_)); - else - return; - case 2: // set_error - std::rethrow_exception(std::get<2>(__result_)); + // The operation completed with set_error, so we need to rethrow the exception. + std::rethrow_exception(std::move(std::get<2>(__result_))); } - std::terminate(); + // The operation completed with set_value, so we can just return the value, which + // may be void. + return static_cast>(std::get<1>(__result_)); } protected: - __expected_t<_Value> __result_; + __expected_t<_Value> __result_{}; }; + ////////////////////////////////////////////////////////////////////////////////////// + // __sender_awaitable: awaitable type returned by as_awaitable when given a sender + // that does not have an as_awaitable member function template struct __sender_awaitable : __sender_awaitable_base<__detail::__value_t<_Sender, _Promise>> { - constexpr __sender_awaitable(_Sender&& sndr, __std::coroutine_handle<_Promise> __hcoro) - noexcept(__nothrow_connectable<_Sender, __receiver>) - : __op_state_(connect(static_cast<_Sender&&>(sndr), - __receiver{ - {&this->__result_, __hcoro} - })) + constexpr explicit __sender_awaitable(_Sender&& __sndr, + __std::coroutine_handle<_Promise> __hcoro) + noexcept(__nothrow_connectable<_Sender, __receiver_t>) + : __opstate_(STDEXEC::connect(static_cast<_Sender&&>(__sndr), + __receiver_t(this->__result_, __hcoro))) {} constexpr void await_suspend(__std::coroutine_handle<_Promise>) noexcept { - STDEXEC::start(__op_state_); + STDEXEC::start(__opstate_); + } + + private: + using __receiver_t = __async_receiver_t<_Sender, _Promise>; + connect_result_t<_Sender, __receiver_t> __opstate_; + }; + + // When the sender is known to complete inline, we can connect and start the operation + // in await_suspend. + template + requires __completes_inline<_Sender, env_of_t<_Promise&>> + struct __sender_awaitable<_Promise, _Sender> + : __sender_awaitable_base<__detail::__value_t<_Sender, _Promise>> + { + constexpr explicit __sender_awaitable(_Sender&& sndr, __ignore) + noexcept(__nothrow_move_constructible<_Sender>) + : __sndr_(static_cast<_Sender&&>(sndr)) + {} + + bool await_suspend(__std::coroutine_handle<_Promise> __hcoro) + { + { + auto __opstate = STDEXEC::connect(static_cast<_Sender&&>(__sndr_), + __receiver_t(this->__result_, __hcoro)); + // The following call to start will complete synchronously, writing its result + // into the __result_ variant. + STDEXEC::start(__opstate); + } + + if (this->__result_.index() == 0) + { + // The operation completed with set_stopped, so we need to call + // unhandled_stopped() on the promise to propagate the stop signal. That will + // result in the coroutine being torn down, so beware. We then resume the + // returned coroutine handle (which may be a noop_coroutine). + __std::coroutine_handle<> __on_stopped = __hcoro.promise().unhandled_stopped(); + __on_stopped.resume(); + + // By returning true, we indicate that the coroutine should not be resumed + // (because it no longer exists). + return true; + } + + // The operation completed with set_value or set_error, so we can just resume the + // current coroutine. await_resume with either return the value or throw as + // appropriate. + return false; } private: - using __receiver = __receiver_t<_Sender, _Promise>; - connect_result_t<_Sender, __receiver> __op_state_; + using __receiver_t = __sync_receiver_t<_Sender, _Promise>; + _Sender __sndr_; }; template @@ -211,7 +302,6 @@ namespace STDEXEC template concept __awaitable_adapted_sender = sender_in<_Sender, env_of_t<_Promise&>> && __minvocable_q<__detail::__value_t, _Sender, _Promise> - && sender_to<_Sender, __receiver_t<_Sender, _Promise>> && requires(_Promise& __promise) { { __promise.unhandled_stopped() diff --git a/include/stdexec/__detail/__task.hpp b/include/stdexec/__detail/__task.hpp index 56cdc82ab..e502743d0 100644 --- a/include/stdexec/__detail/__task.hpp +++ b/include/stdexec/__detail/__task.hpp @@ -145,7 +145,7 @@ namespace STDEXEC template concept __has_allocator_compatible_with = requires(_Rcvr& __rcvr) { - _Alloc(get_allocator(get_env(__rcvr))); + _Alloc(STDEXEC::get_allocator(STDEXEC::get_env(__rcvr))); } || std::default_initializable<_Alloc>; } // namespace __task @@ -213,8 +213,7 @@ namespace STDEXEC private: using __on_stopped_t = __task::__on_stopped; - using __error_variant_t = - __error_types_t, __q1<__decay_t>>; + using __error_variant_t = __error_types_t, __q1<__decay_t>>; using __completions_t = __concat_completion_signatures_t< completion_signatures<__detail::__single_value_sig_t<_Ty>, set_stopped_t()>, @@ -241,10 +240,7 @@ namespace STDEXEC { constexpr explicit __opstate_base(scheduler_type __sched) noexcept : __sch_(std::move(__sched)) - { - // Initialize the errors variant to monostate, the "no error" state: - __errors_.template emplace<0>(); - } + {} virtual void __completed() noexcept = 0; virtual void __canceled() noexcept = 0; @@ -265,7 +261,7 @@ namespace STDEXEC // task::__opstate template template - struct STDEXEC_ATTRIBUTE(empty_bases) task<_Ty, _Env>::__opstate + struct STDEXEC_ATTRIBUTE(empty_bases) task<_Ty, _Env>::__opstate final : __opstate_base , __if_c<__needs_stop_callback<_Rcvr>, __manual_lifetime<__stop_callback_t<_Rcvr>>, __empty> { @@ -290,7 +286,7 @@ namespace STDEXEC } else { - __coro_.promise().__stop_.template emplace<1>(get_stop_token(get_env(__rcvr_))); + __coro_.promise().__stop_.template emplace<1>(get_stop_token(STDEXEC::get_env(__rcvr_))); } } @@ -371,9 +367,7 @@ namespace STDEXEC __stop_callback().__destroy(); } - std::printf("opstate completed, &__errors_ = %p\n", static_cast(&this->__errors_)); - - if (this->__errors_.index() != 0) + if (this->__errors_.index() != __variant_npos) { std::exchange(__coro_, {}).destroy(); __visit(STDEXEC::set_error, std::move(this->__errors_), static_cast<_Rcvr&&>(__rcvr_)); diff --git a/include/stdexec/__detail/__transform_sender.hpp b/include/stdexec/__detail/__transform_sender.hpp index ec4fc8f19..49799d026 100644 --- a/include/stdexec/__detail/__transform_sender.hpp +++ b/include/stdexec/__detail/__transform_sender.hpp @@ -72,7 +72,7 @@ namespace STDEXEC } template - STDEXEC_ATTRIBUTE(nodiscard, host, device) + STDEXEC_ATTRIBUTE(nodiscard, host, device, always_inline) constexpr auto operator()(_Sndr&& __sndr) const noexcept(__nothrow_move_constructible<_Sndr>) -> _Sndr { @@ -118,7 +118,7 @@ namespace STDEXEC struct __compose { template - STDEXEC_ATTRIBUTE(nodiscard, host, device) + STDEXEC_ATTRIBUTE(nodiscard, host, device, always_inline) constexpr auto operator()(_Sndr&& __sndr, _Env const & __env) const noexcept(noexcept(_Fn1()(_Fn2()(static_cast<_Sndr&&>(__sndr), __env), __env))) -> decltype(_Fn1()(_Fn2()(static_cast<_Sndr&&>(__sndr), __env), __env)) @@ -139,7 +139,7 @@ namespace STDEXEC public: // NOT TO SPEC: template - STDEXEC_ATTRIBUTE(nodiscard, host, device) + STDEXEC_ATTRIBUTE(nodiscard, host, device, always_inline) constexpr auto operator()(_Sndr&& __sndr) const noexcept(__nothrow_move_constructible<_Sndr>) // -> _Sndr @@ -148,7 +148,7 @@ namespace STDEXEC } template {}> - STDEXEC_ATTRIBUTE(nodiscard, host, device) + STDEXEC_ATTRIBUTE(nodiscard, host, device, always_inline) constexpr auto operator()(_Sndr && __sndr, _Env const & __env) const noexcept(noexcept(_ImplFn(static_cast<_Sndr&&>(__sndr), __env))) -> decltype(_ImplFn(static_cast<_Sndr&&>(__sndr), __env)) diff --git a/include/stdexec/__detail/__type_traits.hpp b/include/stdexec/__detail/__type_traits.hpp index bb7ad5c9a..c4fc85cec 100644 --- a/include/stdexec/__detail/__type_traits.hpp +++ b/include/stdexec/__detail/__type_traits.hpp @@ -18,7 +18,7 @@ #include "__config.hpp" #include // IWYU pragma: keep for std::terminate -#include // IWYU pragma: keep +#include // IWYU pragma: export #include // IWYU pragma: keep namespace STDEXEC diff --git a/include/stdexec/__detail/__utility.hpp b/include/stdexec/__detail/__utility.hpp index 4959f381b..3cab5f038 100644 --- a/include/stdexec/__detail/__utility.hpp +++ b/include/stdexec/__detail/__utility.hpp @@ -351,14 +351,7 @@ namespace STDEXEC [[noreturn]] inline void unreachable() { - // Uses compiler specific extensions if possible. - // Even if no extension is used, undefined behavior is still raised by - // an empty function body and the noreturn attribute. -# if STDEXEC_MSVC() - __assume(false); // MSVC -# else - __builtin_unreachable(); // everybody else -# endif + STDEXEC_UNREACHABLE(); } #endif } // namespace __std diff --git a/test/stdexec/types/test_task.cpp b/test/stdexec/types/test_task.cpp index 15c3a6014..cb274886b 100644 --- a/test/stdexec/types/test_task.cpp +++ b/test/stdexec/types/test_task.cpp @@ -244,6 +244,31 @@ namespace CHECK(i == 42); } + auto nested() -> ex::task + { + auto sched = co_await ex::read_env(ex::get_scheduler); + static_assert(std::same_as); + co_await ex::schedule(sched); + co_return 42; + } + + auto test_task_awaits_inline_sndr_without_stack_overflow() -> ex::task + { + int result = co_await nested(); + for (int i = 0; i < 1'000'000; ++i) + { + result += co_await ex::just(42); + } + co_return result; + } + + TEST_CASE("test task can await a just_int sender without stack overflow", "[types][task]") + { + auto t = test_task_awaits_inline_sndr_without_stack_overflow(); + auto [i] = ex::sync_wait(std::move(t)).value(); + CHECK(i == 42'000'042); + } + // FUTURE TODO: add support so that `co_await sndr` can return a reference. // auto test_task_awaits_just_ref_sender() -> ex::task { // int value = 42;