include/boost/capy/when_all.hpp

96.9% Lines (95/98) 91.3% Functions (484/530)
Line TLA Hits Source Code
1 //
2 // Copyright (c) 2026 Steve Gerbino
3 //
4 // Distributed under the Boost Software License, Version 1.0. (See accompanying
5 // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6 //
7 // Official repository: https://github.com/cppalliance/capy
8 //
9
10 #ifndef BOOST_CAPY_WHEN_ALL_HPP
11 #define BOOST_CAPY_WHEN_ALL_HPP
12
13 #include <boost/capy/detail/config.hpp>
14 #include <boost/capy/concept/executor.hpp>
15 #include <boost/capy/concept/io_awaitable.hpp>
16 #include <coroutine>
17 #include <boost/capy/ex/io_env.hpp>
18 #include <boost/capy/ex/frame_allocator.hpp>
19 #include <boost/capy/task.hpp>
20
21 #include <array>
22 #include <atomic>
23 #include <exception>
24 #include <optional>
25 #include <stop_token>
26 #include <tuple>
27 #include <type_traits>
28 #include <utility>
29
30 namespace boost {
31 namespace capy {
32
33 namespace detail {
34
/** Type traits that drop `void` entries from a list of result types.

    An awaitable that yields `void` has nothing to store, so it must not
    occupy a slot in the aggregated result tuple.

    `wrap_non_void_t<T>` maps `void` to an empty tuple and any other type
    `T` to `std::tuple<T>`. `filter_void_tuple_t<Ts...>` is the type that
    results from concatenating those per-type tuples in input order.

    Example: filter_void_tuple_t<int, void, string> = tuple<int, string>
*/
template<typename T>
using wrap_non_void_t = std::conditional_t<
    !std::is_void_v<T>,
    std::tuple<T>,
    std::tuple<>>;

template<typename... Ts>
using filter_void_tuple_t = decltype(
    std::tuple_cat(std::declval<wrap_non_void_t<Ts>>()...));
47
/** Storage slot for the result of a single child awaitable in when_all.

    The slot starts out empty and is filled by set(). The value is
    constructed in place, so `T` only has to be move-constructible; it
    does not need to be assignable.

    @tparam T The non-void result type produced by the child.
*/
template<typename T>
struct result_holder
{
    std::optional<T> value_;

    /** Store a result in the slot, replacing any previous value.

        @param v The value to store; it is moved into the slot.
    */
    void set(T v)
    {
        // emplace() rather than operator=: assignment would additionally
        // require T to be move-assignable, which rejects types with
        // const or reference members that can otherwise be stored fine.
        value_.emplace(std::move(v));
    }

    /** Move the stored value out of the slot.

        The slot is not checked; calling this before set() is undefined
        behavior.

        @return The stored value, moved out of the holder.
    */
    T get() &&
    {
        return std::move(*value_);
    }
};

/** Specialization for void results - there is no value to store.
*/
template<>
struct result_holder<void>
{
};
72
73 /** Shared state for when_all operation.
74
75 @tparam Ts The result types of the tasks.
76 */
77 template<typename... Ts>
78 struct when_all_state
79 {
80 static constexpr std::size_t task_count = sizeof...(Ts);
81
82 // Completion tracking - when_all waits for all children
83 std::atomic<std::size_t> remaining_count_;
84
85 // Result storage in input order
86 std::tuple<result_holder<Ts>...> results_;
87
88 // Runner handles - destroyed in await_resume while allocator is valid
89 std::array<std::coroutine_handle<>, task_count> runner_handles_{};
90
91 // Exception storage - first error wins, others discarded
92 std::atomic<bool> has_exception_{false};
93 std::exception_ptr first_exception_;
94
95 // Stop propagation - on error, request stop for siblings
96 std::stop_source stop_source_;
97
98 // Connects parent's stop_token to our stop_source
99 struct stop_callback_fn
100 {
101 std::stop_source* source_;
102 4x void operator()() const { source_->request_stop(); }
103 };
104 using stop_callback_t = std::stop_callback<stop_callback_fn>;
105 std::optional<stop_callback_t> parent_stop_callback_;
106
107 // Parent resumption
108 std::coroutine_handle<> continuation_;
109 io_env const* caller_env_ = nullptr;
110
111 61x when_all_state()
112 61x : remaining_count_(task_count)
113 {
114 61x }
115
116 // Runners self-destruct in final_suspend. No destruction needed here.
117
118 /** Capture an exception (first one wins).
119 */
120 20x void capture_exception(std::exception_ptr ep)
121 {
122 20x bool expected = false;
123 20x if(has_exception_.compare_exchange_strong(
124 expected, true, std::memory_order_relaxed))
125 17x first_exception_ = ep;
126 20x }
127
128 };
129
130 /** Wrapper coroutine that intercepts task completion.
131
132 This runner awaits its assigned task and stores the result in
133 the shared state, or captures the exception and requests stop.
134 */
135 template<typename T, typename... Ts>
136 struct when_all_runner
137 {
138 struct promise_type // : frame_allocating_base // DISABLED FOR TESTING
139 {
140 when_all_state<Ts...>* state_ = nullptr;
141 io_env env_;
142
143 134x when_all_runner get_return_object()
144 {
145 134x return when_all_runner(std::coroutine_handle<promise_type>::from_promise(*this));
146 }
147
148 134x std::suspend_always initial_suspend() noexcept
149 {
150 134x return {};
151 }
152
153 134x auto final_suspend() noexcept
154 {
155 struct awaiter
156 {
157 promise_type* p_;
158
159 58x bool await_ready() const noexcept
160 {
161 58x return false;
162 }
163
164 58x auto await_suspend(std::coroutine_handle<> h) noexcept
165 {
166 // Extract everything needed before self-destruction.
167 58x auto* state = p_->state_;
168 58x auto* counter = &state->remaining_count_;
169 58x auto* caller_env = state->caller_env_;
170 58x auto cont = state->continuation_;
171
172 58x h.destroy();
173
174 // If last runner, dispatch parent for symmetric transfer.
175 58x auto remaining = counter->fetch_sub(1, std::memory_order_acq_rel);
176 58x if(remaining == 1)
177 29x return detail::symmetric_transfer(caller_env->executor.dispatch(cont));
178 29x return detail::symmetric_transfer(std::noop_coroutine());
179 }
180
181 void await_resume() const noexcept
182 {
183 }
184 };
185 134x return awaiter{this};
186 }
187
188 114x void return_void()
189 {
190 114x }
191
192 20x void unhandled_exception()
193 {
194 20x state_->capture_exception(std::current_exception());
195 // Request stop for sibling tasks
196 20x state_->stop_source_.request_stop();
197 20x }
198
199 template<class Awaitable>
200 struct transform_awaiter
201 {
202 std::decay_t<Awaitable> a_;
203 promise_type* p_;
204
205 134x bool await_ready()
206 {
207 134x return a_.await_ready();
208 }
209
210 134x decltype(auto) await_resume()
211 {
212 134x return a_.await_resume();
213 }
214
215 template<class Promise>
216 133x auto await_suspend(std::coroutine_handle<Promise> h)
217 {
218 using R = decltype(a_.await_suspend(h, &p_->env_));
219 if constexpr (std::is_same_v<R, std::coroutine_handle<>>)
220 133x return detail::symmetric_transfer(a_.await_suspend(h, &p_->env_));
221 else
222 return a_.await_suspend(h, &p_->env_);
223 }
224 };
225
226 template<class Awaitable>
227 134x auto await_transform(Awaitable&& a)
228 {
229 using A = std::decay_t<Awaitable>;
230 if constexpr (IoAwaitable<A>)
231 {
232 return transform_awaiter<Awaitable>{
233 268x std::forward<Awaitable>(a), this};
234 }
235 else
236 {
237 static_assert(sizeof(A) == 0, "requires IoAwaitable");
238 }
239 134x }
240 };
241
242 std::coroutine_handle<promise_type> h_;
243
244 134x explicit when_all_runner(std::coroutine_handle<promise_type> h)
245 134x : h_(h)
246 {
247 134x }
248
249 // Enable move for all clang versions - some versions need it
250 when_all_runner(when_all_runner&& other) noexcept : h_(std::exchange(other.h_, nullptr)) {}
251
252 // Non-copyable
253 when_all_runner(when_all_runner const&) = delete;
254 when_all_runner& operator=(when_all_runner const&) = delete;
255 when_all_runner& operator=(when_all_runner&&) = delete;
256
257 134x auto release() noexcept
258 {
259 134x return std::exchange(h_, nullptr);
260 }
261 };
262
263 /** Create a runner coroutine for a single awaitable.
264
265 Awaitable is passed directly to ensure proper coroutine frame storage.
266 */
267 template<std::size_t Index, IoAwaitable Awaitable, typename... Ts>
268 when_all_runner<awaitable_result_t<Awaitable>, Ts...>
269 134x make_when_all_runner(Awaitable inner, when_all_state<Ts...>* state)
270 {
271 using T = awaitable_result_t<Awaitable>;
272 if constexpr (std::is_void_v<T>)
273 {
274 co_await std::move(inner);
275 }
276 else
277 {
278 std::get<Index>(state->results_).set(co_await std::move(inner));
279 }
280 268x }
281
282 /** Internal awaitable that launches all runner coroutines and waits.
283
284 This awaitable is used inside the when_all coroutine to handle
285 the concurrent execution of child awaitables.
286 */
287 template<IoAwaitable... Awaitables>
288 class when_all_launcher
289 {
290 using state_type = when_all_state<awaitable_result_t<Awaitables>...>;
291
292 std::tuple<Awaitables...>* awaitables_;
293 state_type* state_;
294
295 public:
296 61x when_all_launcher(
297 std::tuple<Awaitables...>* awaitables,
298 state_type* state)
299 61x : awaitables_(awaitables)
300 61x , state_(state)
301 {
302 61x }
303
304 61x bool await_ready() const noexcept
305 {
306 61x return sizeof...(Awaitables) == 0;
307 }
308
309 61x std::coroutine_handle<> await_suspend(std::coroutine_handle<> continuation, io_env const* caller_env)
310 {
311 61x state_->continuation_ = continuation;
312 61x state_->caller_env_ = caller_env;
313
314 // Forward parent's stop requests to children
315 61x if(caller_env->stop_token.stop_possible())
316 {
317 16x state_->parent_stop_callback_.emplace(
318 8x caller_env->stop_token,
319 8x typename state_type::stop_callback_fn{&state_->stop_source_});
320
321 8x if(caller_env->stop_token.stop_requested())
322 4x state_->stop_source_.request_stop();
323 }
324
325 // CRITICAL: If the last task finishes synchronously then the parent
326 // coroutine resumes, destroying its frame, and destroying this object
327 // prior to the completion of await_suspend. Therefore, await_suspend
328 // must ensure `this` cannot be referenced after calling `launch_one`
329 // for the last time.
330 61x auto token = state_->stop_source_.get_token();
331 [&]<std::size_t... Is>(std::index_sequence<Is...>) {
332 30x (..., launch_one<Is>(caller_env->executor, token));
333 61x }(std::index_sequence_for<Awaitables...>{});
334
335 // Let signal_completion() handle resumption
336 122x return std::noop_coroutine();
337 61x }
338
339 61x void await_resume() const noexcept
340 {
341 // Results are extracted by the when_all coroutine from state
342 61x }
343
344 private:
345 template<std::size_t I>
346 134x void launch_one(executor_ref caller_ex, std::stop_token token)
347 {
348 134x auto runner = make_when_all_runner<I>(
349 134x std::move(std::get<I>(*awaitables_)), state_);
350
351 134x auto h = runner.release();
352 134x h.promise().state_ = state_;
353 134x h.promise().env_ = io_env{caller_ex, token, state_->caller_env_->frame_allocator};
354
355 134x std::coroutine_handle<> ch{h};
356 134x state_->runner_handles_[I] = ch;
357 134x state_->caller_env_->executor.post(ch);
358 268x }
359 };
360
361 /** Helper to extract a single result, returning empty tuple for void.
362 This is a separate function to work around a GCC-11 ICE that occurs
363 when using nested immediately-invoked lambdas with pack expansion.
364 */
365 template<std::size_t I, typename... Ts>
366 59x auto extract_single_result(when_all_state<Ts...>& state)
367 {
368 using T = std::tuple_element_t<I, std::tuple<Ts...>>;
369 if constexpr (std::is_void_v<T>)
370 4x return std::tuple<>();
371 else
372 55x return std::make_tuple(std::move(std::get<I>(state.results_)).get());
373 }
374
375 /** Extract results from state, filtering void types.
376 */
377 template<typename... Ts>
378 25x auto extract_results(when_all_state<Ts...>& state)
379 {
380 25x return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
381 5x return std::tuple_cat(extract_single_result<Is>(state)...);
382 50x }(std::index_sequence_for<Ts...>{});
383 }
384
385 } // namespace detail
386
387 /** Compute a tuple type with void types filtered out.
388
389 Returns void when all types are void (P2300 aligned),
390 otherwise returns a std::tuple with void types removed.
391
392 Example: non_void_tuple_t<int, void, string> = std::tuple<int, string>
393 Example: non_void_tuple_t<void, void> = void
394 */
395 template<typename... Ts>
396 using non_void_tuple_t = std::conditional_t<
397 std::is_same_v<detail::filter_void_tuple_t<Ts...>, std::tuple<>>,
398 void,
399 detail::filter_void_tuple_t<Ts...>>;
400
401 /** Execute multiple awaitables concurrently and collect their results.
402
403 Launches all awaitables simultaneously and waits for all to complete
404 before returning. Results are collected in input order. If any
405 awaitable throws, cancellation is requested for siblings and the first
406 exception is rethrown after all awaitables complete.
407
408 @li All child awaitables run concurrently on the caller's executor
409 @li Results are returned as a tuple in input order
410 @li Void-returning awaitables do not contribute to the result tuple
411 @li If all awaitables return void, `when_all` returns `task<void>`
412 @li First exception wins; subsequent exceptions are discarded
413 @li Stop is requested for siblings on first error
414 @li Completes only after all children have finished
415
416 @par Thread Safety
417 The returned task must be awaited from a single execution context.
418 Child awaitables execute concurrently but complete through the caller's
419 executor.
420
421 @param awaitables The awaitables to execute concurrently. Each must
422 satisfy @ref IoAwaitable and is consumed (moved-from) when
423 `when_all` is awaited.
424
425 @return A task yielding a tuple of non-void results. Returns
426 `task<void>` when all input awaitables return void.
427
428 @par Example
429
430 @code
431 task<> example()
432 {
433 // Concurrent fetch, results collected in order
434 auto [user, posts] = co_await when_all(
435 fetch_user( id ), // task<User>
436 fetch_posts( id ) // task<std::vector<Post>>
437 );
438
439 // Void awaitables don't contribute to result
440 co_await when_all(
441 log_event( "start" ), // task<void>
442 notify_user( id ) // task<void>
443 );
444 // Returns task<void>, no result tuple
445 }
446 @endcode
447
448 @see IoAwaitable, task
449 */
450 template<IoAwaitable... As>
451 61x [[nodiscard]] auto when_all(As... awaitables)
452 -> task<non_void_tuple_t<awaitable_result_t<As>...>>
453 {
454 using result_type = non_void_tuple_t<awaitable_result_t<As>...>;
455
456 // State is stored in the coroutine frame, using the frame allocator
457 detail::when_all_state<awaitable_result_t<As>...> state;
458
459 // Store awaitables in the frame
460 std::tuple<As...> awaitable_tuple(std::move(awaitables)...);
461
462 // Launch all awaitables and wait for completion
463 co_await detail::when_all_launcher<As...>(&awaitable_tuple, &state);
464
465 // Propagate first exception if any.
466 // Safe without explicit acquire: capture_exception() is sequenced-before
467 // signal_completion()'s acq_rel fetch_sub, which synchronizes-with the
468 // last task's decrement that resumes this coroutine.
469 if(state.first_exception_)
470 std::rethrow_exception(state.first_exception_);
471
472 // Extract and return results
473 if constexpr (std::is_void_v<result_type>)
474 co_return;
475 else
476 co_return detail::extract_results(state);
477 122x }
478
479 } // namespace capy
480 } // namespace boost
481
482 #endif
483