LCOV - code coverage report
Current view: top level - capy - when_all.hpp (source / functions) Coverage Total Hit Missed
Test: coverage_remapped.info Lines: 98.0 % 98 96 2
Test Date: 2026-03-04 22:59:25 Functions: 89.0 % 617 549 68

           TLA  Line data    Source code
       1                 : //
       2                 : // Copyright (c) 2026 Steve Gerbino
       3                 : //
       4                 : // Distributed under the Boost Software License, Version 1.0. (See accompanying
       5                 : // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
       6                 : //
       7                 : // Official repository: https://github.com/cppalliance/capy
       8                 : //
       9                 : 
      10                 : #ifndef BOOST_CAPY_WHEN_ALL_HPP
      11                 : #define BOOST_CAPY_WHEN_ALL_HPP
      12                 : 
      13                 : #include <boost/capy/detail/config.hpp>
      14                 : #include <boost/capy/concept/executor.hpp>
      15                 : #include <boost/capy/concept/io_awaitable.hpp>
      16                 : #include <coroutine>
      17                 : #include <boost/capy/ex/io_env.hpp>
      18                 : #include <boost/capy/ex/frame_allocator.hpp>
      19                 : #include <boost/capy/task.hpp>
      20                 : 
      21                 : #include <array>
      22                 : #include <atomic>
      23                 : #include <exception>
      24                 : #include <optional>
      25                 : #include <stop_token>
      26                 : #include <tuple>
      27                 : #include <type_traits>
      28                 : #include <utility>
      29                 : 
      30                 : namespace boost {
      31                 : namespace capy {
      32                 : 
      33                 : namespace detail {
      34                 : 
      35                 : /** Type trait to filter void types from a tuple.
      36                 : 
      37                 :     Void-returning tasks do not contribute a value to the result tuple.
      38                 :     This trait computes the filtered result type.
      39                 : 
      40                 :     Example: filter_void_tuple_t<int, void, string> = tuple<int, string>
      41                 : */
      42                 : template<typename T>
      43                 : using wrap_non_void_t = std::conditional_t<std::is_void_v<T>, std::tuple<>, std::tuple<T>>;
      44                 : 
      45                 : template<typename... Ts>
      46                 : using filter_void_tuple_t = decltype(std::tuple_cat(std::declval<wrap_non_void_t<Ts>>()...));
      47                 : 
      48                 : /** Holds the result of a single task within when_all.
      49                 : */
      50                 : template<typename T>
      51                 : struct result_holder
      52                 : {
      53                 :     std::optional<T> value_;
      54                 : 
      55 HIT          62 :     void set(T v)
      56                 :     {
      57              62 :         value_ = std::move(v);
      58              62 :     }
      59                 : 
      60              55 :     T get() &&
      61                 :     {
      62              55 :         return std::move(*value_);
      63                 :     }
      64                 : };
      65                 : 
      66                 : /** Specialization for void tasks - no value storage needed.
      67                 : */
      68                 : template<>
      69                 : struct result_holder<void>
      70                 : {
      71                 : };
      72                 : 
      73                 : /** Shared state for when_all operation.
      74                 : 
      75                 :     @tparam Ts The result types of the tasks.
      76                 : */
      77                 : template<typename... Ts>
      78                 : struct when_all_state
      79                 : {
      80                 :     static constexpr std::size_t task_count = sizeof...(Ts);
      81                 : 
      82                 :     // Completion tracking - when_all waits for all children
      83                 :     std::atomic<std::size_t> remaining_count_;
      84                 : 
      85                 :     // Result storage in input order
      86                 :     std::tuple<result_holder<Ts>...> results_;
      87                 : 
      88                 :     // Runner handles - destroyed in await_resume while allocator is valid
      89                 :     std::array<std::coroutine_handle<>, task_count> runner_handles_{};
      90                 : 
      91                 :     // Exception storage - first error wins, others discarded
      92                 :     std::atomic<bool> has_exception_{false};
      93                 :     std::exception_ptr first_exception_;
      94                 : 
      95                 :     // Stop propagation - on error, request stop for siblings
      96                 :     std::stop_source stop_source_;
      97                 : 
      98                 :     // Connects parent's stop_token to our stop_source
      99                 :     struct stop_callback_fn
     100                 :     {
     101                 :         std::stop_source* source_;
     102               4 :         void operator()() const { source_->request_stop(); }
     103                 :     };
     104                 :     using stop_callback_t = std::stop_callback<stop_callback_fn>;
     105                 :     std::optional<stop_callback_t> parent_stop_callback_;
     106                 : 
     107                 :     // Parent resumption
     108                 :     std::coroutine_handle<> continuation_;
     109                 :     io_env const* caller_env_ = nullptr;
     110                 : 
     111              61 :     when_all_state()
     112              61 :         : remaining_count_(task_count)
     113                 :     {
     114              61 :     }
     115                 : 
     116                 :     // Runners self-destruct in final_suspend. No destruction needed here.
     117                 : 
     118                 :     /** Capture an exception (first one wins).
     119                 :     */
     120              20 :     void capture_exception(std::exception_ptr ep)
     121                 :     {
     122              20 :         bool expected = false;
     123              20 :         if(has_exception_.compare_exchange_strong(
     124                 :             expected, true, std::memory_order_relaxed))
     125              17 :             first_exception_ = ep;
     126              20 :     }
     127                 : 
     128                 : };
     129                 : 
     130                 : /** Wrapper coroutine that intercepts task completion.
     131                 : 
     132                 :     This runner awaits its assigned task and stores the result in
     133                 :     the shared state, or captures the exception and requests stop.
     134                 : */
     135                 : template<typename T, typename... Ts>
     136                 : struct when_all_runner
     137                 : {
     138                 :     struct promise_type // : frame_allocating_base  // DISABLED FOR TESTING
     139                 :     {
     140                 :         when_all_state<Ts...>* state_ = nullptr;
     141                 :         io_env env_;
     142                 : 
     143             134 :         when_all_runner get_return_object()
     144                 :         {
     145             134 :             return when_all_runner(std::coroutine_handle<promise_type>::from_promise(*this));
     146                 :         }
     147                 : 
     148             134 :         std::suspend_always initial_suspend() noexcept
     149                 :         {
     150             134 :             return {};
     151                 :         }
     152                 : 
     153             134 :         auto final_suspend() noexcept
     154                 :         {
     155                 :             struct awaiter
     156                 :             {
     157                 :                 promise_type* p_;
     158                 : 
     159             134 :                 bool await_ready() const noexcept
     160                 :                 {
     161             134 :                     return false;
     162                 :                 }
     163                 : 
     164             134 :                 auto await_suspend(std::coroutine_handle<> h) noexcept
     165                 :                 {
     166                 :                     // Extract everything needed before self-destruction.
     167             134 :                     auto* state = p_->state_;
     168             134 :                     auto* counter = &state->remaining_count_;
     169             134 :                     auto* caller_env = state->caller_env_;
     170             134 :                     auto cont = state->continuation_;
     171                 : 
     172             134 :                     h.destroy();
     173                 : 
     174                 :                     // If last runner, dispatch parent for symmetric transfer.
     175             134 :                     auto remaining = counter->fetch_sub(1, std::memory_order_acq_rel);
     176             134 :                     if(remaining == 1)
     177              61 :                         return detail::symmetric_transfer(caller_env->executor.dispatch(cont));
     178              73 :                     return detail::symmetric_transfer(std::noop_coroutine());
     179                 :                 }
     180                 : 
     181 MIS           0 :                 void await_resume() const noexcept
     182                 :                 {
     183               0 :                 }
     184                 :             };
     185 HIT         134 :             return awaiter{this};
     186                 :         }
     187                 : 
     188             114 :         void return_void()
     189                 :         {
     190             114 :         }
     191                 : 
     192              20 :         void unhandled_exception()
     193                 :         {
     194              20 :             state_->capture_exception(std::current_exception());
     195                 :             // Request stop for sibling tasks
     196              20 :             state_->stop_source_.request_stop();
     197              20 :         }
     198                 : 
     199                 :         template<class Awaitable>
     200                 :         struct transform_awaiter
     201                 :         {
     202                 :             std::decay_t<Awaitable> a_;
     203                 :             promise_type* p_;
     204                 : 
     205             134 :             bool await_ready()
     206                 :             {
     207             134 :                 return a_.await_ready();
     208                 :             }
     209                 : 
     210             134 :             decltype(auto) await_resume()
     211                 :             {
     212             134 :                 return a_.await_resume();
     213                 :             }
     214                 : 
     215                 :             template<class Promise>
     216             133 :             auto await_suspend(std::coroutine_handle<Promise> h)
     217                 :             {
     218                 :                 using R = decltype(a_.await_suspend(h, &p_->env_));
     219                 :                 if constexpr (std::is_same_v<R, std::coroutine_handle<>>)
     220             133 :                     return detail::symmetric_transfer(a_.await_suspend(h, &p_->env_));
     221                 :                 else
     222                 :                     return a_.await_suspend(h, &p_->env_);
     223                 :             }
     224                 :         };
     225                 : 
     226                 :         template<class Awaitable>
     227             134 :         auto await_transform(Awaitable&& a)
     228                 :         {
     229                 :             using A = std::decay_t<Awaitable>;
     230                 :             if constexpr (IoAwaitable<A>)
     231                 :             {
     232                 :                 return transform_awaiter<Awaitable>{
     233             268 :                     std::forward<Awaitable>(a), this};
     234                 :             }
     235                 :             else
     236                 :             {
     237                 :                 static_assert(sizeof(A) == 0, "requires IoAwaitable");
     238                 :             }
     239             134 :         }
     240                 :     };
     241                 : 
     242                 :     std::coroutine_handle<promise_type> h_;
     243                 : 
     244             134 :     explicit when_all_runner(std::coroutine_handle<promise_type> h)
     245             134 :         : h_(h)
     246                 :     {
     247             134 :     }
     248                 : 
     249                 :     // Enable move for all clang versions - some versions need it
     250                 :     when_all_runner(when_all_runner&& other) noexcept : h_(std::exchange(other.h_, nullptr)) {}
     251                 : 
     252                 :     // Non-copyable
     253                 :     when_all_runner(when_all_runner const&) = delete;
     254                 :     when_all_runner& operator=(when_all_runner const&) = delete;
     255                 :     when_all_runner& operator=(when_all_runner&&) = delete;
     256                 : 
     257             134 :     auto release() noexcept
     258                 :     {
     259             134 :         return std::exchange(h_, nullptr);
     260                 :     }
     261                 : };
     262                 : 
     263                 : /** Create a runner coroutine for a single awaitable.
     264                 : 
     265                 :     Awaitable is passed directly to ensure proper coroutine frame storage.
     266                 : */
     267                 : template<std::size_t Index, IoAwaitable Awaitable, typename... Ts>
     268                 : when_all_runner<awaitable_result_t<Awaitable>, Ts...>
     269             134 : make_when_all_runner(Awaitable inner, when_all_state<Ts...>* state)
     270                 : {
     271                 :     using T = awaitable_result_t<Awaitable>;
     272                 :     if constexpr (std::is_void_v<T>)
     273                 :     {
     274                 :         co_await std::move(inner);
     275                 :     }
     276                 :     else
     277                 :     {
     278                 :         std::get<Index>(state->results_).set(co_await std::move(inner));
     279                 :     }
     280             268 : }
     281                 : 
     282                 : /** Internal awaitable that launches all runner coroutines and waits.
     283                 : 
     284                 :     This awaitable is used inside the when_all coroutine to handle
     285                 :     the concurrent execution of child awaitables.
     286                 : */
     287                 : template<IoAwaitable... Awaitables>
     288                 : class when_all_launcher
     289                 : {
     290                 :     using state_type = when_all_state<awaitable_result_t<Awaitables>...>;
     291                 : 
     292                 :     std::tuple<Awaitables...>* awaitables_;
     293                 :     state_type* state_;
     294                 : 
     295                 : public:
     296              61 :     when_all_launcher(
     297                 :         std::tuple<Awaitables...>* awaitables,
     298                 :         state_type* state)
     299              61 :         : awaitables_(awaitables)
     300              61 :         , state_(state)
     301                 :     {
     302              61 :     }
     303                 : 
     304              61 :     bool await_ready() const noexcept
     305                 :     {
     306              61 :         return sizeof...(Awaitables) == 0;
     307                 :     }
     308                 : 
     309              61 :     std::coroutine_handle<> await_suspend(std::coroutine_handle<> continuation, io_env const* caller_env)
     310                 :     {
     311              61 :         state_->continuation_ = continuation;
     312              61 :         state_->caller_env_ = caller_env;
     313                 : 
     314                 :         // Forward parent's stop requests to children
     315              61 :         if(caller_env->stop_token.stop_possible())
     316                 :         {
     317              16 :             state_->parent_stop_callback_.emplace(
     318               8 :                 caller_env->stop_token,
     319               8 :                 typename state_type::stop_callback_fn{&state_->stop_source_});
     320                 : 
     321               8 :             if(caller_env->stop_token.stop_requested())
     322               4 :                 state_->stop_source_.request_stop();
     323                 :         }
     324                 : 
     325                 :         // CRITICAL: If the last task finishes synchronously then the parent
     326                 :         // coroutine resumes, destroying its frame, and destroying this object
     327                 :         // prior to the completion of await_suspend. Therefore, await_suspend
     328                 :         // must ensure `this` cannot be referenced after calling `launch_one`
     329                 :         // for the last time.
     330              61 :         auto token = state_->stop_source_.get_token();
     331              62 :         [&]<std::size_t... Is>(std::index_sequence<Is...>) {
     332              61 :             (..., launch_one<Is>(caller_env->executor, token));
     333              61 :         }(std::index_sequence_for<Awaitables...>{});
     334                 : 
     335                 :         // Let signal_completion() handle resumption
     336             122 :         return std::noop_coroutine();
     337              61 :     }
     338                 : 
     339              61 :     void await_resume() const noexcept
     340                 :     {
     341                 :         // Results are extracted by the when_all coroutine from state
     342              61 :     }
     343                 : 
     344                 : private:
     345                 :     template<std::size_t I>
     346             134 :     void launch_one(executor_ref caller_ex, std::stop_token token)
     347                 :     {
     348             134 :         auto runner = make_when_all_runner<I>(
     349             134 :             std::move(std::get<I>(*awaitables_)), state_);
     350                 : 
     351             134 :         auto h = runner.release();
     352             134 :         h.promise().state_ = state_;
     353             134 :         h.promise().env_ = io_env{caller_ex, token, state_->caller_env_->frame_allocator};
     354                 : 
     355             134 :         std::coroutine_handle<> ch{h};
     356             134 :         state_->runner_handles_[I] = ch;
     357             134 :         state_->caller_env_->executor.post(ch);
     358             268 :     }
     359                 : };
     360                 : 
     361                 : /** Helper to extract a single result, returning empty tuple for void.
     362                 :     This is a separate function to work around a GCC-11 ICE that occurs
     363                 :     when using nested immediately-invoked lambdas with pack expansion.
     364                 : */
     365                 : template<std::size_t I, typename... Ts>
     366              59 : auto extract_single_result(when_all_state<Ts...>& state)
     367                 : {
     368                 :     using T = std::tuple_element_t<I, std::tuple<Ts...>>;
     369                 :     if constexpr (std::is_void_v<T>)
     370               4 :         return std::tuple<>();
     371                 :     else
     372              55 :         return std::make_tuple(std::move(std::get<I>(state.results_)).get());
     373                 : }
     374                 : 
     375                 : /** Extract results from state, filtering void types.
     376                 : */
     377                 : template<typename... Ts>
     378              25 : auto extract_results(when_all_state<Ts...>& state)
     379                 : {
     380              45 :     return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
     381              26 :         return std::tuple_cat(extract_single_result<Is>(state)...);
     382              50 :     }(std::index_sequence_for<Ts...>{});
     383                 : }
     384                 : 
     385                 : } // namespace detail
     386                 : 
     387                 : /** Compute a tuple type with void types filtered out.
     388                 : 
     389                 :     Returns void when all types are void (P2300 aligned),
     390                 :     otherwise returns a std::tuple with void types removed.
     391                 : 
     392                 :     Example: non_void_tuple_t<int, void, string> = std::tuple<int, string>
     393                 :     Example: non_void_tuple_t<void, void> = void
     394                 : */
     395                 : template<typename... Ts>
     396                 : using non_void_tuple_t = std::conditional_t<
     397                 :     std::is_same_v<detail::filter_void_tuple_t<Ts...>, std::tuple<>>,
     398                 :     void,
     399                 :     detail::filter_void_tuple_t<Ts...>>;
     400                 : 
     401                 : /** Execute multiple awaitables concurrently and collect their results.
     402                 : 
     403                 :     Launches all awaitables simultaneously and waits for all to complete
     404                 :     before returning. Results are collected in input order. If any
     405                 :     awaitable throws, cancellation is requested for siblings and the first
     406                 :     exception is rethrown after all awaitables complete.
     407                 : 
     408                 :     @li All child awaitables run concurrently on the caller's executor
     409                 :     @li Results are returned as a tuple in input order
     410                 :     @li Void-returning awaitables do not contribute to the result tuple
     411                 :     @li If all awaitables return void, `when_all` returns `task<void>`
     412                 :     @li First exception wins; subsequent exceptions are discarded
     413                 :     @li Stop is requested for siblings on first error
     414                 :     @li Completes only after all children have finished
     415                 : 
     416                 :     @par Thread Safety
     417                 :     The returned task must be awaited from a single execution context.
     418                 :     Child awaitables execute concurrently but complete through the caller's
     419                 :     executor.
     420                 : 
     421                 :     @param awaitables The awaitables to execute concurrently. Each must
     422                 :         satisfy @ref IoAwaitable and is consumed (moved-from) when
     423                 :         `when_all` is awaited.
     424                 : 
     425                 :     @return A task yielding a tuple of non-void results. Returns
     426                 :         `task<void>` when all input awaitables return void.
     427                 : 
     428                 :     @par Example
     429                 : 
     430                 :     @code
     431                 :     task<> example()
     432                 :     {
     433                 :         // Concurrent fetch, results collected in order
     434                 :         auto [user, posts] = co_await when_all(
     435                 :             fetch_user( id ),      // task<User>
     436                 :             fetch_posts( id )      // task<std::vector<Post>>
     437                 :         );
     438                 : 
     439                 :         // Void awaitables don't contribute to result
     440                 :         co_await when_all(
     441                 :             log_event( "start" ),  // task<void>
     442                 :             notify_user( id )      // task<void>
     443                 :         );
     444                 :         // Returns task<void>, no result tuple
     445                 :     }
     446                 :     @endcode
     447                 : 
     448                 :     @see IoAwaitable, task
     449                 : */
     450                 : template<IoAwaitable... As>
     451              61 : [[nodiscard]] auto when_all(As... awaitables)
     452                 :     -> task<non_void_tuple_t<awaitable_result_t<As>...>>
     453                 : {
     454                 :     using result_type = non_void_tuple_t<awaitable_result_t<As>...>;
     455                 : 
     456                 :     // State is stored in the coroutine frame, using the frame allocator
     457                 :     detail::when_all_state<awaitable_result_t<As>...> state;
     458                 : 
     459                 :     // Store awaitables in the frame
     460                 :     std::tuple<As...> awaitable_tuple(std::move(awaitables)...);
     461                 : 
     462                 :     // Launch all awaitables and wait for completion
     463                 :     co_await detail::when_all_launcher<As...>(&awaitable_tuple, &state);
     464                 : 
     465                 :     // Propagate first exception if any.
     466                 :     // Safe without explicit acquire: capture_exception() is sequenced-before
     467                 :     // signal_completion()'s acq_rel fetch_sub, which synchronizes-with the
     468                 :     // last task's decrement that resumes this coroutine.
     469                 :     if(state.first_exception_)
     470                 :         std::rethrow_exception(state.first_exception_);
     471                 : 
     472                 :     // Extract and return results
     473                 :     if constexpr (std::is_void_v<result_type>)
     474                 :         co_return;
     475                 :     else
     476                 :         co_return detail::extract_results(state);
     477             122 : }
     478                 : 
     479                 : } // namespace capy
     480                 : } // namespace boost
     481                 : 
     482                 : #endif
        

Generated by: LCOV version 2.3