diff --git a/.githooks/pre-commit b/.githooks/pre-commit index 762160d..a433ed0 100755 --- a/.githooks/pre-commit +++ b/.githooks/pre-commit @@ -35,11 +35,15 @@ done cp .githooks/readme-template.md README.md template_contents=$(cat 'README.md') -coro_event_cpp_contents=$(cat 'examples/coro_event.cpp') -echo "${template_contents/\$\{EXAMPLE_CORO_EVENT_CPP\}/$coro_event_cpp_contents}" > README.md +example_contents=$(cat 'examples/coro_event.cpp') +echo "${template_contents/\$\{EXAMPLE_CORO_EVENT_CPP\}/$example_contents}" > README.md template_contents=$(cat 'README.md') -coro_latch_cpp_contents=$(cat 'examples/coro_latch.cpp') -echo "${template_contents/\$\{EXAMPLE_CORO_LATCH_CPP\}/$coro_latch_cpp_contents}" > README.md +example_contents=$(cat 'examples/coro_latch.cpp') +echo "${template_contents/\$\{EXAMPLE_CORO_LATCH_CPP\}/$example_contents}" > README.md + +template_contents=$(cat 'README.md') +example_contents=$(cat 'examples/coro_mutex.cpp') +echo "${template_contents/\$\{EXAMPLE_CORO_MUTEX_CPP\}/$example_contents}" > README.md git add README.md diff --git a/.githooks/readme-template.md b/.githooks/readme-template.md index 037bf86..1028a24 100644 --- a/.githooks/readme-template.md +++ b/.githooks/readme-template.md @@ -20,8 +20,8 @@ - coro::latch - coro::mutex - coro::sync_wait(awaitable) - - coro::when_all_awaitabe(awaitable...) -> coro::task... - - coro::when_all(awaitable...) -> T... (Future) + - coro::when_all(awaitable...) -> coro::task... + - coro::when_all_results(awaitable...) -> T... (Future) * Schedulers - coro::thread_pool for coroutine cooperative multitasking - coro::io_scheduler for driving i/o events, uses thread_pool for coroutine execution @@ -73,17 +73,30 @@ Expected output: ```bash $ ./examples/coro_latch latch task is now waiting on all children tasks... -work task 1 is working... -work task 1 is done, counting down on the latch -work task 2 is working... -work task 2 is done, counting down on the latch -work task 3 is working... -work task 3 is done, counting down on the latch -work task 4 is working... -work task 4 is done, counting down on the latch -work task 5 is working... -work task 5 is done, counting down on the latch -latch task children tasks completed, resuming. +worker task 1 is working... +worker task 2 is working... +worker task 3 is working... +worker task 4 is working... +worker task 5 is working... +worker task 1 is done, counting down on the latch +worker task 2 is done, counting down on the latch +worker task 3 is done, counting down on the latch +worker task 4 is done, counting down on the latch +worker task 5 is done, counting down on the latch +latch task dependency tasks completed, resuming. +``` + +### coro::mutex + +```C++ +${EXAMPLE_CORO_MUTEX_CPP} +``` + +Expected output, note that the output will vary from run to run based on how the thread pool workers +are scheduled and in what order they acquire the mutex lock: +```bash +$ ./examples/coro_mutex +1, 2, 3, 4, 5, 6, 7, 8, 10, 9, 12, 11, 13, 14, 15, 16, 17, 18, 19, 21, 22, 20, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 47, 48, 49, 46, 50, 51, 52, 53, 54, 55, 57, 58, 59, 56, 60, 62, 61, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, ``` ## Usage diff --git a/README.md b/README.md index 389b6ad..f1ad931 100644 --- a/README.md +++ b/README.md @@ -20,8 +20,8 @@ - coro::latch - coro::mutex - coro::sync_wait(awaitable) - - coro::when_all_awaitabe(awaitable...) -> coro::task... - - coro::when_all(awaitable...) -> T... (Future) + - coro::when_all(awaitable...) -> coro::task... + - coro::when_all_results(awaitable...) -> T... (Future) * Schedulers - coro::thread_pool for coroutine cooperative multitasking - coro::io_scheduler for driving i/o events, uses thread_pool for coroutine execution @@ -68,11 +68,9 @@ int main() co_return; }; - // Synchronously wait until all the tasks are completed, this is intentionally - // starting the first 3 wait tasks prior to the final set task so the waiters suspend - // their coroutine before being resumed. - coro::sync_wait( - coro::when_all_awaitable(make_wait_task(e, 1), make_wait_task(e, 2), make_wait_task(e, 3), make_set_task(e))); + // Given more than a single task to synchronously wait on, use when_all() to execute all the + // tasks concurrently on this thread and then sync_wait() for them all to complete. + coro::sync_wait(coro::when_all(make_wait_task(e, 1), make_wait_task(e, 2), make_wait_task(e, 3), make_set_task(e))); } ``` @@ -98,35 +96,41 @@ have completed before proceeding. int main() { + // Complete worker tasks faster on a thread pool, using the io_scheduler version so the worker + // tasks can yield for a specific amount of time to mimic difficult work. The pool is only + // setup with a single thread to showcase yield_for(). + coro::io_scheduler tp{coro::io_scheduler::options{.pool = coro::thread_pool::options{.thread_count = 1}}}; + // This task will wait until the given latch setters have completed. auto make_latch_task = [](coro::latch& l) -> coro::task { + // It seems like the dependent worker tasks could be created here, but in that case it would + // be superior to simply do: `co_await coro::when_all(tasks);` + // It is also important to note that the last dependent task will resume the waiting latch + // task prior to actually completing -- thus the dependent task's frame could be destroyed + // by the latch task completing before it gets a chance to finish after calling resume() on + // the latch task! + std::cout << "latch task is now waiting on all children tasks...\n"; co_await l; - std::cout << "latch task children tasks completed, resuming.\n"; + std::cout << "latch task dependency tasks completed, resuming.\n"; co_return; }; // This task does 'work' and counts down on the latch when completed. The final child task to // complete will end up resuming the latch task when the latch's count reaches zero. - auto make_worker_task = [](coro::latch& l, int64_t i) -> coro::task { - std::cout << "work task " << i << " is working...\n"; - std::cout << "work task " << i << " is done, counting down on the latch\n"; + auto make_worker_task = [](coro::io_scheduler& tp, coro::latch& l, int64_t i) -> coro::task { + // Schedule the worker task onto the thread pool. + co_await tp.schedule(); + std::cout << "worker task " << i << " is working...\n"; + // Do some expensive calculations, yield to mimic work...! Its also important to never use + // std::this_thread::sleep_for() within the context of coroutines, it will block the thread + // and other tasks that are ready to execute will be blocked. + co_await tp.yield_for(std::chrono::milliseconds{i * 20}); + std::cout << "worker task " << i << " is done, counting down on the latch\n"; l.count_down(); co_return; }; - // It is important to note that the latch task must not 'own' the worker tasks within its - // coroutine stack frame because the final worker task thread will execute the latch task upon - // setting the latch counter to zero. This means that: - // 1) final worker task calls count_down() => 0 - // 2) resume execution of latch task to its next suspend point or completion, IF completed - // then this coroutine's stack frame is destroyed! - // 3) final worker task continues exection - // If the latch task 'own's the worker task objects then they will destruct prior to step (3) - // if the latch task completes on that resume, and it will be attempting to execute an already - // destructed coroutine frame. - // This example correctly has the latch task and all its waiting tasks on the same scope/frame - // to avoid this issue. const int64_t num_tasks{5}; coro::latch l{num_tasks}; std::vector> tasks{}; @@ -135,11 +139,11 @@ int main() tasks.emplace_back(make_latch_task(l)); for (int64_t i = 1; i <= num_tasks; ++i) { - tasks.emplace_back(make_worker_task(l, i)); + tasks.emplace_back(make_worker_task(tp, l, i)); } // Wait for all tasks to complete. - coro::sync_wait(coro::when_all_awaitable(tasks)); + coro::sync_wait(coro::when_all(tasks)); } ``` @@ -147,17 +151,67 @@ Expected output: ```bash $ ./examples/coro_latch latch task is now waiting on all children tasks... -work task 1 is working... -work task 1 is done, counting down on the latch -work task 2 is working... -work task 2 is done, counting down on the latch -work task 3 is working... -work task 3 is done, counting down on the latch -work task 4 is working... -work task 4 is done, counting down on the latch -work task 5 is working... -work task 5 is done, counting down on the latch -latch task children tasks completed, resuming. +worker task 1 is working... +worker task 2 is working... +worker task 3 is working... +worker task 4 is working... +worker task 5 is working... +worker task 1 is done, counting down on the latch +worker task 2 is done, counting down on the latch +worker task 3 is done, counting down on the latch +worker task 4 is done, counting down on the latch +worker task 5 is done, counting down on the latch +latch task dependency tasks completed, resuming. +``` + +### coro::mutex + +```C++ +#include +#include + +int main() +{ + coro::thread_pool tp{coro::thread_pool::options{.thread_count = 4}}; + std::vector output{}; + coro::mutex mutex; + + auto make_critical_section_task = [&](uint64_t i) -> coro::task { + co_await tp.schedule(); + // To acquire a mutex lock co_await its lock() function. Upon acquiring the lock the + // lock() function returns a coro::scoped_lock that holds the mutex and automatically + // unlocks the mutex upon destruction. This behaves just like std::scoped_lock. + { + auto scoped_lock = co_await mutex.lock(); + output.emplace_back(i); + } // <-- scoped lock unlocks the mutex here. + co_return; + }; + + const size_t num_tasks{100}; + std::vector> tasks{}; + tasks.reserve(num_tasks); + for (size_t i = 1; i <= num_tasks; ++i) + { + tasks.emplace_back(make_critical_section_task(i)); + } + + coro::sync_wait(coro::when_all(tasks)); + + // The output will be variable per run depending on how the tasks are picked up on the + // thread pool workers. + for (const auto& value : output) + { + std::cout << value << ", "; + } +} +``` + +Expected output, note that the output will vary from run to run based on how the thread pool workers +are scheduled and in what order they acquire the mutex lock: +```bash +$ ./examples/coro_mutex +1, 2, 3, 4, 5, 6, 7, 8, 10, 9, 12, 11, 13, 14, 15, 16, 17, 18, 19, 21, 22, 20, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 47, 48, 49, 46, 50, 51, 52, 53, 54, 55, 57, 58, 59, 56, 60, 62, 61, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, ``` ## Usage diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 6f87bd9..e95d541 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -9,9 +9,14 @@ add_executable(coro_latch coro_latch.cpp) target_compile_features(coro_latch PUBLIC cxx_std_20) target_link_libraries(coro_latch PUBLIC libcoro) +add_executable(coro_mutex coro_mutex.cpp) +target_compile_features(coro_mutex PUBLIC cxx_std_20) +target_link_libraries(coro_mutex PUBLIC libcoro) + if(${CMAKE_CXX_COMPILER_ID} MATCHES "GNU") target_compile_options(coro_event PUBLIC -fcoroutines -Wall -Wextra -pipe) target_compile_options(coro_latch PUBLIC -fcoroutines -Wall -Wextra -pipe) + target_compile_options(coro_mutex PUBLIC -fcoroutines -Wall -Wextra -pipe) elseif(${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") message(FATAL_ERROR "Clang is currently not supported.") else() diff --git a/examples/coro_event.cpp b/examples/coro_event.cpp index 3b9d182..4469294 100644 --- a/examples/coro_event.cpp +++ b/examples/coro_event.cpp @@ -20,9 +20,7 @@ int main() co_return; }; - // Synchronously wait until all the tasks are completed, this is intentionally - // starting the first 3 wait tasks prior to the final set task so the waiters suspend - // their coroutine before being resumed. - coro::sync_wait( - coro::when_all_awaitable(make_wait_task(e, 1), make_wait_task(e, 2), make_wait_task(e, 3), make_set_task(e))); + // Given more than a single task to synchronously wait on, use when_all() to execute all the + // tasks concurrently on this thread and then sync_wait() for them all to complete. + coro::sync_wait(coro::when_all(make_wait_task(e, 1), make_wait_task(e, 2), make_wait_task(e, 3), make_set_task(e))); } diff --git a/examples/coro_latch.cpp b/examples/coro_latch.cpp index f106518..47371bb 100644 --- a/examples/coro_latch.cpp +++ b/examples/coro_latch.cpp @@ -3,35 +3,41 @@ int main() { + // Complete worker tasks faster on a thread pool, using the io_scheduler version so the worker + // tasks can yield for a specific amount of time to mimic difficult work. The pool is only + // setup with a single thread to showcase yield_for(). + coro::io_scheduler tp{coro::io_scheduler::options{.pool = coro::thread_pool::options{.thread_count = 1}}}; + // This task will wait until the given latch setters have completed. auto make_latch_task = [](coro::latch& l) -> coro::task { + // It seems like the dependent worker tasks could be created here, but in that case it would + // be superior to simply do: `co_await coro::when_all(tasks);` + // It is also important to note that the last dependent task will resume the waiting latch + // task prior to actually completing -- thus the dependent task's frame could be destroyed + // by the latch task completing before it gets a chance to finish after calling resume() on + // the latch task! + std::cout << "latch task is now waiting on all children tasks...\n"; co_await l; - std::cout << "latch task children tasks completed, resuming.\n"; + std::cout << "latch task dependency tasks completed, resuming.\n"; co_return; }; // This task does 'work' and counts down on the latch when completed. The final child task to // complete will end up resuming the latch task when the latch's count reaches zero. - auto make_worker_task = [](coro::latch& l, int64_t i) -> coro::task { - std::cout << "work task " << i << " is working...\n"; - std::cout << "work task " << i << " is done, counting down on the latch\n"; + auto make_worker_task = [](coro::io_scheduler& tp, coro::latch& l, int64_t i) -> coro::task { + // Schedule the worker task onto the thread pool. + co_await tp.schedule(); + std::cout << "worker task " << i << " is working...\n"; + // Do some expensive calculations, yield to mimic work...! Its also important to never use + // std::this_thread::sleep_for() within the context of coroutines, it will block the thread + // and other tasks that are ready to execute will be blocked. + co_await tp.yield_for(std::chrono::milliseconds{i * 20}); + std::cout << "worker task " << i << " is done, counting down on the latch\n"; l.count_down(); co_return; }; - // It is important to note that the latch task must not 'own' the worker tasks within its - // coroutine stack frame because the final worker task thread will execute the latch task upon - // setting the latch counter to zero. This means that: - // 1) final worker task calls count_down() => 0 - // 2) resume execution of latch task to its next suspend point or completion, IF completed - // then this coroutine's stack frame is destroyed! - // 3) final worker task continues exection - // If the latch task 'own's the worker task objects then they will destruct prior to step (3) - // if the latch task completes on that resume, and it will be attempting to execute an already - // destructed coroutine frame. - // This example correctly has the latch task and all its waiting tasks on the same scope/frame - // to avoid this issue. const int64_t num_tasks{5}; coro::latch l{num_tasks}; std::vector> tasks{}; @@ -40,9 +46,9 @@ int main() tasks.emplace_back(make_latch_task(l)); for (int64_t i = 1; i <= num_tasks; ++i) { - tasks.emplace_back(make_worker_task(l, i)); + tasks.emplace_back(make_worker_task(tp, l, i)); } // Wait for all tasks to complete. - coro::sync_wait(coro::when_all_awaitable(tasks)); + coro::sync_wait(coro::when_all(tasks)); } diff --git a/examples/coro_mutex.cpp b/examples/coro_mutex.cpp new file mode 100644 index 0000000..4a26c8d --- /dev/null +++ b/examples/coro_mutex.cpp @@ -0,0 +1,38 @@ +#include +#include + +int main() +{ + coro::thread_pool tp{coro::thread_pool::options{.thread_count = 4}}; + std::vector output{}; + coro::mutex mutex; + + auto make_critical_section_task = [&](uint64_t i) -> coro::task { + co_await tp.schedule(); + // To acquire a mutex lock co_await its lock() function. Upon acquiring the lock the + // lock() function returns a coro::scoped_lock that holds the mutex and automatically + // unlocks the mutex upon destruction. This behaves just like std::scoped_lock. + { + auto scoped_lock = co_await mutex.lock(); + output.emplace_back(i); + } // <-- scoped lock unlocks the mutex here. + co_return; + }; + + const size_t num_tasks{100}; + std::vector> tasks{}; + tasks.reserve(num_tasks); + for (size_t i = 1; i <= num_tasks; ++i) + { + tasks.emplace_back(make_critical_section_task(i)); + } + + coro::sync_wait(coro::when_all(tasks)); + + // The output will be variable per run depending on how the tasks are picked up on the + // thread pool workers. + for (const auto& value : output) + { + std::cout << value << ", "; + } +} diff --git a/inc/coro/latch.hpp b/inc/coro/latch.hpp index 815f29e..cec59c3 100644 --- a/inc/coro/latch.hpp +++ b/inc/coro/latch.hpp @@ -1,6 +1,7 @@ #pragma once #include "coro/event.hpp" +#include "coro/thread_pool.hpp" #include @@ -41,6 +42,7 @@ public: auto remaining() const noexcept -> std::size_t { return m_count.load(std::memory_order::acquire); } /** + * If the latch counter goes to zero then the task awaiting the latch is resumed. * @param n The number of tasks to complete towards the latch, defaults to 1. */ auto count_down(std::ptrdiff_t n = 1) noexcept -> void @@ -51,6 +53,20 @@ public: } } + /** + * If the latch counter goes to then the task awaiting the latch is resumed on the given + * thread pool. + * @param tp The thread pool to schedule the task that is waiting on the latch on. + * @param n The number of tasks to complete towards the latch, defaults to 1. + */ + auto count_down(coro::thread_pool& tp, std::ptrdiff_t n = 1) noexcept -> void + { + if (m_count.fetch_sub(n, std::memory_order::acq_rel) <= n) + { + m_event.set(tp); + } + } + auto operator co_await() const noexcept -> event::awaiter { return m_event.operator co_await(); } private: diff --git a/inc/coro/mutex.hpp b/inc/coro/mutex.hpp index 50c7b33..3db2aff 100644 --- a/inc/coro/mutex.hpp +++ b/inc/coro/mutex.hpp @@ -2,36 +2,55 @@ #include #include -#include #include namespace coro { +class mutex; + +class scoped_lock +{ + friend class mutex; + +public: + enum class lock_strategy + { + /// The lock is already acquired, adopt it as the new owner. + adopt + }; + + explicit scoped_lock(mutex& m, lock_strategy strategy = lock_strategy::adopt) : m_mutex(&m) + { + // Future -> support acquiring the lock? Not sure how to do that without being able to + // co_await in the constructor. + (void)strategy; + } + ~scoped_lock(); + + scoped_lock(const scoped_lock&) = delete; + scoped_lock(scoped_lock&& other) : m_mutex(std::exchange(other.m_mutex, nullptr)) {} + auto operator=(const scoped_lock&) -> scoped_lock& = delete; + auto operator =(scoped_lock&& other) -> scoped_lock& + { + if (std::addressof(other) != this) + { + m_mutex = std::exchange(other.m_mutex, nullptr); + } + return *this; + } + + /** + * Unlocks the scoped lock prior to it going out of scope. + */ + auto unlock() -> void; + +private: + mutex* m_mutex{nullptr}; +}; + class mutex { public: - struct scoped_lock - { - friend class mutex; - - scoped_lock(mutex& m) : m_mutex(m) {} - ~scoped_lock() { m_mutex.unlock(); } - - mutex& m_mutex; - }; - - struct awaiter - { - awaiter(mutex& m) noexcept : m_mutex(m) {} - - auto await_ready() const noexcept -> bool; - auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool; - auto await_resume() noexcept -> scoped_lock; - - mutex& m_mutex; - std::coroutine_handle<> m_awaiting_coroutine; - }; - explicit mutex() noexcept = default; ~mutex() = default; @@ -40,16 +59,45 @@ public: auto operator=(const mutex&) -> mutex& = delete; auto operator=(mutex&&) -> mutex& = delete; - auto lock() -> awaiter; + struct lock_operation + { + explicit lock_operation(mutex& m) : m_mutex(m) {} + + auto await_ready() const noexcept -> bool { return m_mutex.try_lock(); } + auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool; + auto await_resume() noexcept -> scoped_lock { return scoped_lock{m_mutex}; } + + mutex& m_mutex; + std::coroutine_handle<> m_awaiting_coroutine; + lock_operation* m_next{nullptr}; + }; + + /** + * To acquire the mutex's lock co_await this function. Upon acquiring the lock it returns + * a coro::scoped_lock which will hold the mutex until the coro::scoped_lock destructs. + * @return A co_await'able operation to acquire the mutex. + */ + [[nodiscard]] auto lock() -> lock_operation { return lock_operation{*this}; }; + + /** + * Attempts to lock the mutex. + * @return True if the mutex lock was acquired, otherwise false. + */ auto try_lock() -> bool; + + /** + * Releases the mutex's lock. + */ auto unlock() -> void; private: - friend class scoped_lock; + // friend class scoped_lock; + friend class lock_operation; - std::atomic m_state{false}; - std::mutex m_waiter_mutex{}; - std::deque m_waiter_list{}; + std::atomic m_state{false}; + std::mutex m_waiter_mutex{}; + lock_operation* m_head_waiter{nullptr}; + lock_operation* m_tail_waiter{nullptr}; }; } // namespace coro diff --git a/inc/coro/sync_wait.hpp b/inc/coro/sync_wait.hpp index 46dfbeb..14e99fb 100644 --- a/inc/coro/sync_wait.hpp +++ b/inc/coro/sync_wait.hpp @@ -1,6 +1,7 @@ #pragma once #include "coro/concepts/awaitable.hpp" +#include "coro/when_all.hpp" #include #include @@ -183,28 +184,28 @@ private: }; template< - concepts::awaitable awaitable, - typename return_type = concepts::awaitable_traits::awaiter_return_type> -static auto make_sync_wait_task(awaitable&& a) -> sync_wait_task + concepts::awaitable awaitable_type, + typename return_type = concepts::awaitable_traits::awaiter_return_type> +static auto make_sync_wait_task(awaitable_type&& a) -> sync_wait_task { if constexpr (std::is_void_v) { - co_await std::forward(a); + co_await std::forward(a); co_return; } else { - co_yield co_await std::forward(a); + co_yield co_await std::forward(a); } } } // namespace detail -template -auto sync_wait(awaitable&& a) -> decltype(auto) +template +auto sync_wait(awaitable_type&& a) -> decltype(auto) { detail::sync_wait_event e{}; - auto task = detail::make_sync_wait_task(std::forward(a)); + auto task = detail::make_sync_wait_task(std::forward(a)); task.start(e); e.wait(); diff --git a/inc/coro/when_all.hpp b/inc/coro/when_all.hpp index 5dde715..a5480f9 100644 --- a/inc/coro/when_all.hpp +++ b/inc/coro/when_all.hpp @@ -6,6 +6,7 @@ #include #include #include +#include namespace coro { @@ -453,7 +454,7 @@ static auto make_when_all_task(awaitable&& a) -> when_all_task } // namespace detail template -[[nodiscard]] auto when_all_awaitable(awaitables_type&&... awaitables) +[[nodiscard]] auto when_all(awaitables_type&&... awaitables) { return detail::when_all_ready_awaitable::awaiter_return_type>...>>( @@ -461,9 +462,10 @@ template } template< - concepts::awaitable awaitable, - typename return_type = concepts::awaitable_traits::awaiter_return_type> -[[nodiscard]] auto when_all_awaitable(std::vector& awaitables) + concepts::awaitable awaitable_type, + typename return_type = concepts::awaitable_traits::awaiter_return_type, + typename allocator_type = std::allocator> +[[nodiscard]] auto when_all(std::vector& awaitables) -> detail::when_all_ready_awaitable>> { std::vector> tasks; diff --git a/src/mutex.cpp b/src/mutex.cpp index e8ef2da..da2a054 100644 --- a/src/mutex.cpp +++ b/src/mutex.cpp @@ -2,9 +2,49 @@ namespace coro { -auto mutex::lock() -> awaiter +scoped_lock::~scoped_lock() { - return awaiter(*this); + if (m_mutex != nullptr) + { + m_mutex->unlock(); + } +} + +auto scoped_lock::unlock() -> void +{ + if (m_mutex != nullptr) + { + m_mutex->unlock(); + m_mutex = nullptr; + } +} + +auto mutex::lock_operation::await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool +{ + std::scoped_lock lk{m_mutex.m_waiter_mutex}; + if (m_mutex.try_lock()) + { + // If we just straight up acquire the lock, don't suspend. + return false; + } + + // The lock is currently held, so append ourself to the waiter list. + if (m_mutex.m_tail_waiter == nullptr) + { + // If there are no current waiters this lock operation is the head and tail. + m_mutex.m_head_waiter = this; + m_mutex.m_tail_waiter = this; + } + else + { + // Update the current tail pointer to ourself. + m_mutex.m_tail_waiter->m_next = this; + // Update the tail pointer on the mutex to ourself. + m_mutex.m_tail_waiter = this; + } + + m_awaiting_coroutine = awaiting_coroutine; + return true; } auto mutex::try_lock() -> bool @@ -15,58 +55,36 @@ auto mutex::try_lock() -> bool auto mutex::unlock() -> void { - // Get the next waiter before releasing the lock. - awaiter* next{nullptr}; + // Acquire the next waiter before releasing _or_ moving ownship of the lock. + lock_operation* next{nullptr}; { std::scoped_lock lk{m_waiter_mutex}; - if (!m_waiter_list.empty()) + if (m_head_waiter != nullptr) { - next = m_waiter_list.front(); - m_waiter_list.pop_front(); + next = m_head_waiter; + m_head_waiter = m_head_waiter->m_next; + + // Null out the tail waiter if this was the last waiter. + if (m_head_waiter == nullptr) + { + m_tail_waiter = nullptr; + } + } + else + { + // If there were no waiters, release the lock. This is done under the waiter list being + // locked so another thread doesn't add themselves to the waiter list before the lock + // is actually released. + m_state.exchange(false, std::memory_order::release); } } - // Unlock the mutex - m_state.exchange(false, std::memory_order::release); - - // If there was a awaiter, resume it. Here would be good place to _resume_ the waiter onto - // the thread pool to distribute the work, this currently implementation will end up having - // every waiter on the mutex jump onto a single thread. + // If there were any waiters resume the next in line, this will pass ownership of the mutex to + // that waiter, only the final waiter in the list actually unlocks the mutex. if (next != nullptr) { next->m_awaiting_coroutine.resume(); } } -auto mutex::awaiter::await_ready() const noexcept -> bool -{ - return m_mutex.try_lock(); -} - -auto mutex::awaiter::await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool -{ - m_awaiting_coroutine = awaiting_coroutine; - - { - // Its possible between await_ready() and await_suspend() the lock was released, - // if thats the case acquire it immediately. - std::scoped_lock lk{m_mutex.m_waiter_mutex}; - if (m_mutex.m_waiter_list.empty() && m_mutex.try_lock()) - { - return false; - } - - // Ok its still held, add ourself to the waiter list. - m_mutex.m_waiter_list.emplace_back(this); - } - - // The mutex is still locked and we've added this to the waiter list, suspend now. - return true; -} - -auto mutex::awaiter::await_resume() noexcept -> scoped_lock -{ - return scoped_lock{m_mutex}; -} - } // namespace coro diff --git a/test/bench.cpp b/test/bench.cpp index d24ded8..7d40e5d 100644 --- a/test/bench.cpp +++ b/test/bench.cpp @@ -64,7 +64,7 @@ TEST_CASE("benchmark counter func coro::sync_wait(awaitable)", "[benchmark]") REQUIRE(counter == iterations); } -TEST_CASE("benchmark counter func coro::sync_wait(coro::when_all_awaitable(awaitable)) x10", "[benchmark]") +TEST_CASE("benchmark counter func coro::sync_wait(coro::when_all(awaitable)) x10", "[benchmark]") { constexpr std::size_t iterations = default_iterations; uint64_t counter{0}; @@ -74,13 +74,12 @@ TEST_CASE("benchmark counter func coro::sync_wait(coro::when_all_awaitable(await for (std::size_t i = 0; i < iterations; i += 10) { - auto tasks = coro::sync_wait(coro::when_all_awaitable(f(), f(), f(), f(), f(), f(), f(), f(), f(), f())); + auto tasks = coro::sync_wait(coro::when_all(f(), f(), f(), f(), f(), f(), f(), f(), f(), f())); std::apply([&counter](auto&&... t) { ((counter += t.return_value()), ...); }, tasks); } - print_stats( - "benchmark counter func coro::sync_wait(coro::when_all_awaitable(awaitable))", iterations, start, sc::now()); + print_stats("benchmark counter func coro::sync_wait(coro::when_all(awaitable))", iterations, start, sc::now()); REQUIRE(counter == iterations); } @@ -171,7 +170,7 @@ TEST_CASE("benchmark counter task scheduler{1} yield", "[benchmark]") tasks.emplace_back(make_task()); } - coro::sync_wait(coro::when_all_awaitable(tasks)); + coro::sync_wait(coro::when_all(tasks)); auto stop = sc::now(); print_stats("benchmark counter task scheduler{1} yield", ops, start, stop); @@ -204,7 +203,7 @@ TEST_CASE("benchmark counter task scheduler{1} yield_for", "[benchmark]") tasks.emplace_back(make_task()); } - coro::sync_wait(coro::when_all_awaitable(tasks)); + coro::sync_wait(coro::when_all(tasks)); auto stop = sc::now(); print_stats("benchmark counter task scheduler{1} yield", ops, start, stop); @@ -252,7 +251,7 @@ TEST_CASE("benchmark counter task scheduler await event from another coroutine", tasks.emplace_back(resume_func(i)); } - coro::sync_wait(coro::when_all_awaitable(tasks)); + coro::sync_wait(coro::when_all(tasks)); auto stop = sc::now(); print_stats("benchmark counter task scheduler await event from another coroutine", ops, start, stop); @@ -433,7 +432,7 @@ TEST_CASE("benchmark tcp_server echo server", "[benchmark]") { c.tasks.emplace_back(make_client_task(c)); } - coro::sync_wait(coro::when_all_awaitable(c.tasks)); + coro::sync_wait(coro::when_all(c.tasks)); c.scheduler.shutdown(); }}); } diff --git a/test/net/test_tcp_server.cpp b/test/net/test_tcp_server.cpp index 539c3f6..5cccf16 100644 --- a/test/net/test_tcp_server.cpp +++ b/test/net/test_tcp_server.cpp @@ -78,5 +78,5 @@ TEST_CASE("tcp_server ping server", "[tcp_server]") co_return; }; - coro::sync_wait(coro::when_all_awaitable(make_server_task(), make_client_task())); + coro::sync_wait(coro::when_all(make_server_task(), make_client_task())); } diff --git a/test/net/test_udp_peers.cpp b/test/net/test_udp_peers.cpp index 2745ad6..62f0014 100644 --- a/test/net/test_udp_peers.cpp +++ b/test/net/test_udp_peers.cpp @@ -41,7 +41,7 @@ TEST_CASE("udp one way") co_return; }; - coro::sync_wait(coro::when_all_awaitable(make_recv_task(), make_send_task())); + coro::sync_wait(coro::when_all(make_recv_task(), make_send_task())); } TEST_CASE("udp echo peers") @@ -110,7 +110,7 @@ TEST_CASE("udp echo peers") co_return; }; - coro::sync_wait(coro::when_all_awaitable( + coro::sync_wait(coro::when_all( make_peer_task(8081, 8080, false, peer2_msg, peer1_msg), make_peer_task(8080, 8081, true, peer1_msg, peer2_msg))); } diff --git a/test/test_io_scheduler.cpp b/test/test_io_scheduler.cpp index 66bfa9d..bd107c5 100644 --- a/test/test_io_scheduler.cpp +++ b/test/test_io_scheduler.cpp @@ -49,7 +49,7 @@ TEST_CASE("io_scheduler submit mutiple tasks", "[io_scheduler]") tasks.emplace_back(make_task()); } - coro::sync_wait(coro::when_all_awaitable(tasks)); + coro::sync_wait(coro::when_all(tasks)); REQUIRE(counter == n); } @@ -79,8 +79,7 @@ TEST_CASE("io_scheduler task with multiple events", "[io_scheduler]") e.set(); }; - coro::sync_wait( - coro::when_all_awaitable(make_wait_task(), make_set_task(e1), make_set_task(e2), make_set_task(e3))); + coro::sync_wait(coro::when_all(make_wait_task(), make_set_task(e1), make_set_task(e2), make_set_task(e3))); REQUIRE(counter == 3); @@ -107,7 +106,7 @@ TEST_CASE("io_scheduler task with read poll", "[io_scheduler]") co_return; }; - coro::sync_wait(coro::when_all_awaitable(make_poll_read_task(), make_poll_write_task())); + coro::sync_wait(coro::when_all(make_poll_read_task(), make_poll_write_task())); s.shutdown(); REQUIRE(s.empty()); @@ -134,7 +133,7 @@ TEST_CASE("io_scheduler task with read poll with timeout", "[io_scheduler]") co_return; }; - coro::sync_wait(coro::when_all_awaitable(make_poll_read_task(), make_poll_write_task())); + coro::sync_wait(coro::when_all(make_poll_read_task(), make_poll_write_task())); s.shutdown(); REQUIRE(s.empty()); @@ -182,7 +181,7 @@ TEST_CASE("io_scheduler task with read poll timeout", "[io_scheduler]") // co_return; // }; -// coro::sync_wait(coro::when_all_awaitable(make_poll_task(), make_close_task())); +// coro::sync_wait(coro::when_all(make_poll_task(), make_close_task())); // s.shutdown(); // REQUIRE(s.empty()); @@ -214,7 +213,7 @@ TEST_CASE("io_scheduler separate thread resume", "[io_scheduler]") co_return; }; - coro::sync_wait(coro::when_all_awaitable(make_s1_task(), make_s2_task())); + coro::sync_wait(coro::when_all(make_s1_task(), make_s2_task())); s1.shutdown(); REQUIRE(s1.empty()); @@ -307,8 +306,7 @@ TEST_CASE("io_scheduler with basic task", "[io_scheduler]") auto func = [&]() -> coro::task { co_await s.schedule(); - auto output_tasks = - co_await coro::when_all_awaitable(add_data(1), add_data(1), add_data(1), add_data(1), add_data(1)); + auto output_tasks = co_await coro::when_all(add_data(1), add_data(1), add_data(1), add_data(1), add_data(1)); int counter{0}; std::apply([&counter](auto&&... tasks) -> void { ((counter += tasks.return_value()), ...); }, output_tasks); @@ -491,7 +489,7 @@ TEST_CASE("io_scheduler multipler event waiters", "[io_scheduler]") tasks.emplace_back(func()); } - auto results = co_await coro::when_all_awaitable(tasks); + auto results = co_await coro::when_all(tasks); uint64_t counter{0}; for (const auto& task : results) @@ -506,7 +504,7 @@ TEST_CASE("io_scheduler multipler event waiters", "[io_scheduler]") e.set(s); }; - coro::sync_wait(coro::when_all_awaitable(spawn(), release())); + coro::sync_wait(coro::when_all(spawn(), release())); } TEST_CASE("io_scheduler self generating coroutine (stack overflow check)", "[io_scheduler]") diff --git a/test/test_mutex.cpp b/test/test_mutex.cpp index 32074f8..229fbd6 100644 --- a/test/test_mutex.cpp +++ b/test/test_mutex.cpp @@ -13,10 +13,18 @@ TEST_CASE("mutex single waiter not locked", "[mutex]") auto make_emplace_task = [&](coro::mutex& m) -> coro::task { std::cerr << "Acquiring lock\n"; - auto scoped_lock = co_await m.lock(); - std::cerr << "lock acquired, emplacing back 1\n"; - output.emplace_back(1); - std::cerr << "coroutine done\n"; + { + auto scoped_lock = co_await m.lock(); + REQUIRE_FALSE(m.try_lock()); + std::cerr << "lock acquired, emplacing back 1\n"; + output.emplace_back(1); + std::cerr << "coroutine done\n"; + } + + // The scoped lock should release the lock upon destructing. + REQUIRE(m.try_lock()); + m.unlock(); + co_return; }; @@ -43,6 +51,10 @@ TEST_CASE("mutex many waiters until event", "[mutex]") co_await tp.schedule(); std::cerr << "id = " << id << " waiting to acquire the lock\n"; auto scoped_lock = co_await m.lock(); + + // Should always be locked upon acquiring the locks. + REQUIRE_FALSE(m.try_lock()); + std::cerr << "id = " << id << " lock acquired\n"; value.fetch_add(1, std::memory_order::relaxed); std::cerr << "id = " << id << " coroutine done\n"; @@ -53,6 +65,7 @@ TEST_CASE("mutex many waiters until event", "[mutex]") co_await tp.schedule(); std::cerr << "block task acquiring lock\n"; auto scoped_lock = co_await m.lock(); + REQUIRE_FALSE(m.try_lock()); std::cerr << "block task acquired lock, waiting on event\n"; co_await e; co_return; @@ -76,7 +89,24 @@ TEST_CASE("mutex many waiters until event", "[mutex]") tasks.emplace_back(make_set_task()); - coro::sync_wait(coro::when_all_awaitable(tasks)); + coro::sync_wait(coro::when_all(tasks)); REQUIRE(value == 4); +} + +TEST_CASE("mutex scoped_lock unlock prior to scope exit", "[mutex]") +{ + coro::mutex m; + + auto make_task = [&]() -> coro::task { + { + auto lk = co_await m.lock(); + REQUIRE_FALSE(m.try_lock()); + lk.unlock(); + REQUIRE(m.try_lock()); + } + co_return; + }; + + coro::sync_wait(make_task()); } \ No newline at end of file diff --git a/test/test_thread_pool.cpp b/test/test_thread_pool.cpp index 46db68f..ab1b767 100644 --- a/test/test_thread_pool.cpp +++ b/test/test_thread_pool.cpp @@ -26,7 +26,7 @@ TEST_CASE("thread_pool one worker many tasks tuple", "[thread_pool]") co_return 50; }; - auto tasks = coro::sync_wait(coro::when_all_awaitable(f(), f(), f(), f(), f())); + auto tasks = coro::sync_wait(coro::when_all(f(), f(), f(), f(), f())); REQUIRE(std::tuple_size() == 5); uint64_t counter{0}; @@ -49,7 +49,7 @@ TEST_CASE("thread_pool one worker many tasks vector", "[thread_pool]") input_tasks.emplace_back(f()); input_tasks.emplace_back(f()); - auto output_tasks = coro::sync_wait(coro::when_all_awaitable(input_tasks)); + auto output_tasks = coro::sync_wait(coro::when_all(input_tasks)); REQUIRE(output_tasks.size() == 3); @@ -79,7 +79,7 @@ TEST_CASE("thread_pool N workers 100k tasks", "[thread_pool]") input_tasks.emplace_back(make_task(tp)); } - auto output_tasks = coro::sync_wait(coro::when_all_awaitable(input_tasks)); + auto output_tasks = coro::sync_wait(coro::when_all(input_tasks)); REQUIRE(output_tasks.size() == iterations); uint64_t counter{0}; @@ -189,5 +189,5 @@ TEST_CASE("thread_pool event jump threads", "[thread_pool]") co_return; }; - coro::sync_wait(coro::when_all_awaitable(make_tp1_task(), make_tp2_task())); + coro::sync_wait(coro::when_all(make_tp1_task(), make_tp2_task())); } \ No newline at end of file diff --git a/test/test_when_all.cpp b/test/test_when_all.cpp index 22861b8..e6af1dc 100644 --- a/test/test_when_all.cpp +++ b/test/test_when_all.cpp @@ -2,11 +2,11 @@ #include -TEST_CASE("when_all_awaitable single task with tuple container", "[when_all]") +TEST_CASE("when_all single task with tuple container", "[when_all]") { auto make_task = [](uint64_t amount) -> coro::task { co_return amount; }; - auto output_tasks = coro::sync_wait(coro::when_all_awaitable(make_task(100))); + auto output_tasks = coro::sync_wait(coro::when_all(make_task(100))); REQUIRE(std::tuple_size() == 1); uint64_t counter{0}; @@ -15,11 +15,11 @@ TEST_CASE("when_all_awaitable single task with tuple container", "[when_all]") REQUIRE(counter == 100); } -TEST_CASE("when_all_awaitable multiple tasks with tuple container", "[when_all]") +TEST_CASE("when_all multiple tasks with tuple container", "[when_all]") { auto make_task = [](uint64_t amount) -> coro::task { co_return amount; }; - auto output_tasks = coro::sync_wait(coro::when_all_awaitable(make_task(100), make_task(50), make_task(20))); + auto output_tasks = coro::sync_wait(coro::when_all(make_task(100), make_task(50), make_task(20))); REQUIRE(std::tuple_size() == 3); uint64_t counter{0}; @@ -28,14 +28,14 @@ TEST_CASE("when_all_awaitable multiple tasks with tuple container", "[when_all]" REQUIRE(counter == 170); } -TEST_CASE("when_all_awaitable single task with vector container", "[when_all]") +TEST_CASE("when_all single task with vector container", "[when_all]") { auto make_task = [](uint64_t amount) -> coro::task { co_return amount; }; std::vector> input_tasks; input_tasks.emplace_back(make_task(100)); - auto output_tasks = coro::sync_wait(coro::when_all_awaitable(input_tasks)); + auto output_tasks = coro::sync_wait(coro::when_all(input_tasks)); REQUIRE(output_tasks.size() == 1); uint64_t counter{0}; @@ -47,7 +47,7 @@ TEST_CASE("when_all_awaitable single task with vector container", "[when_all]") REQUIRE(counter == 100); } -TEST_CASE("when_all_ready multple task withs vector container", "[when_all]") +TEST_CASE("when_all multple task withs vector container", "[when_all]") { auto make_task = [](uint64_t amount) -> coro::task { co_return amount; }; @@ -57,7 +57,7 @@ TEST_CASE("when_all_ready multple task withs vector container", "[when_all]") input_tasks.emplace_back(make_task(550)); input_tasks.emplace_back(make_task(1000)); - auto output_tasks = coro::sync_wait(coro::when_all_awaitable(input_tasks)); + auto output_tasks = coro::sync_wait(coro::when_all(input_tasks)); REQUIRE(output_tasks.size() == 4); uint64_t counter{0};