diff --git a/.githooks/pre-commit b/.githooks/pre-commit new file mode 100755 index 0000000..8b2c874 --- /dev/null +++ b/.githooks/pre-commit @@ -0,0 +1,39 @@ +#!/usr/bin/sh + +FILE_EXTS=".c .h .cpp .hpp .cc .hh .cxx .tcc" + +# Determins if a file has the right extension to be clang-format'ed. +should_clang_format() { + local filename=$(basename "$1") + local extension=".${filename##*.}" + local ext + + local result=0 + + # Ignore the test/catch.hpp file + if [[ "$1" != *"catch"* ]]; then + for ext in $FILE_EXTS; do + # Otherwise, if the extension is in the array of extensions to reformat, echo 1. + [[ "$ext" == "$extension" ]] && result=1 && break + done + fi + + echo $result +} + +# Run the clang-format across the project's changed files. +for file in $(git diff-index --cached --name-only HEAD); do + if [ -f "${file}" ] && [ "$(should_clang_format "${file}")" != "0" ] ; then + echo "clang-format ${file}" + clang-format -i --style=file "${file}" + git add "${file}" + fi +done + +# Update the README.md example code with the given macros. +template_contents=$(cat '.githooks/readme-template.md') +# All code examples in markdown should be indeded ' ' +coro_event_cpp_contents=$(cat 'examples/coro_event.cpp') + +echo "${template_contents/\$\{EXAMPLE_CORO_EVENT_CPP\}/$coro_event_cpp_contents}" > README.md +git add README.md diff --git a/.githooks/readme-template.md b/.githooks/readme-template.md new file mode 100644 index 0000000..d261b81 --- /dev/null +++ b/.githooks/readme-template.md @@ -0,0 +1,150 @@ +# libcoro C++20 linux coroutine library + +[![CI](https://github.com/jbaldwin/libcoro/workflows/build/badge.svg)](https://github.com/jbaldwin/libcoro/workflows/build/badge.svg) +[![Coverage Status](https://coveralls.io/repos/github/jbaldwin/libcoro/badge.svg?branch=master)](https://coveralls.io/github/jbaldwin/libcoro?branch=master) +[![language][badge.language]][language] +[![license][badge.license]][license] + +**libcoro** is licensed under the Apache 2.0 license. + +**libcoro** is meant to provide low level coroutine constructs for building larger applications, the current focus is around high performance networking coroutine support. + +## Overview + * C++20 coroutines! + * Modern Safe C++20 API + * Higher level coroutine constructs + ** coro::task + ** coro::generator + ** coro::event + ** coro::latch + ** coro::mutex + ** coro::sync_wait(awaitable) + *** coro::when_all(awaitable...) + * Schedulers + ** coro::thread_pool for coroutine cooperative multitasking + ** coro::io_scheduler for driving i/o events, uses thread_pool + *** epoll driver implemented + *** io_uring driver planned (will be required for file i/o) + * Coroutine Networking + ** coro::net::dns_resolver for async dns, leverages libc-ares + ** coro::tcp_client and coro::tcp_server + ** coro::udp_peer + +### A note on co_await +Its important to note with coroutines that depending on the construct used _any_ `co_await` has the +potential to switch the thread that is executing the currently running coroutine. In general this shouldn't +affect the way any user of the library would write code except for `thread_local`. Usage of `thread_local` +should be extremely careful and _never_ used across any `co_await` boundary do to thread switching and +work stealing on thread pools. + +### coro::event +The `coro::event` is a thread safe async tool to have 1 or more waiters suspend for an event to be set +before proceeding. The implementation of event currently will resume execution of all waiters on the +thread that sets the event. If the event is already set when a waiter goes to wait on the thread they +will simply continue executing with no suspend or wait time incurred. + +```C++ +${EXAMPLE_CORO_EVENT_CPP} +``` + +Expected output: +```bash +$ ./Debug/examples/coro_event +task 1 is waiting on the event... +task 2 is waiting on the event... +task 3 is waiting on the event... +set task is triggering the event +task 3 event triggered, now resuming. +task 2 event triggered, now resuming. +task 1 event triggered, now resuming. +``` + +## Usage + +### Requirements + C++20 Compiler with coroutine support + g++10.2 is tested + CMake + make or ninja + pthreads + gcov/lcov (For generating coverage only) + +### Instructions + +#### Cloning the project +This project uses gitsubmodules, to properly checkout this project use: + + git clone --recurse-submodules + +This project depends on the following projects: + * [libc-ares](https://github.com/c-ares/c-ares) For async DNS resolver. + +#### Building + mkdir Release && cd Release + cmake -DCMAKE_BUILD_TYPE=Release .. + cmake --build . + +CMake Options: + +| Name | Default | Description | +|:-----------------------|:--------|:--------------------------------------------------------------| +| LIBCORO_BUILD_TESTS | ON | Should the tests be built? | +| LIBCORO_CODE_COVERAGE | OFF | Should code coverage be enabled? Requires tests to be enabled | +| LIBCORO_BUILD_EXAMPLES | ON | Should the examples be built? | + +#### Adding to your project + +##### add_subdirectory() + +```cmake +# Include the checked out libcoro code in your CMakeLists.txt file +add_subdirectory(path/to/libcoro) + +# Link the libcoro cmake target to your project(s). +target_link_libraries(${PROJECT_NAME} PUBLIC libcoro) + +``` + +##### FetchContent +CMake can include the project directly by downloading the source, compiling and linking to your project via FetchContent, below is an example on how you might do this within your project. + + +```cmake +cmake_minimum_required(VERSION 3.11) + +# Fetch the project and make it available for use. +include(FetchContent) +FetchContent_Declare( + libcoro + GIT_REPOSITORY https://github.com/jbaldwin/libcoro.git + GIT_TAG +) +FetchContent_MakeAvailable(libcoro) + +# Link the libcoro cmake target to your project(s). +target_link_libraries(${PROJECT_NAME} PUBLIC libcoro) + +``` + +#### Tests +The tests will automatically be run by github actions on creating a pull request. They can also be ran locally: + + # Invoke via cmake: + ctest -VV + + # Or invoke directly, can pass the name of tests to execute, the framework used is catch2 + # catch2 supports '*' wildcards to run multiple tests or comma delimited ',' test names. + # The below will run all tests with "tcp_server" prefix in their test name. + ./Debug/test/libcoro_test "tcp_server*" + +### Support + +File bug reports, feature requests and questions using [GitHub libcoro Issues](https://github.com/jbaldwin/libcoro/issues) + +Copyright © 2020-2021 Josh Baldwin + +[badge.language]: https://img.shields.io/badge/language-C%2B%2B20-yellow.svg +[badge.license]: https://img.shields.io/badge/license-Apache--2.0-blue + +[language]: https://en.wikipedia.org/wiki/C%2B%2B17 +[license]: https://en.wikipedia.org/wiki/Apache_License diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ee6b286..0e87888 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -66,7 +66,7 @@ jobs: cd build-debug-g++ cmake \ -GNinja \ - -DCORO_CODE_COVERAGE=ON \ + -DLIBCORO_CODE_COVERAGE=ON \ -DCMAKE_BUILD_TYPE=Debug \ -DCMAKE_C_COMPILER=gcc \ -DCMAKE_CXX_COMPILER=g++ \ diff --git a/CMakeLists.txt b/CMakeLists.txt index edabb7c..a16544a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,11 +1,20 @@ cmake_minimum_required(VERSION 3.16) -project(coro CXX) +project(libcoro CXX) -option(CORO_BUILD_TESTS "Build the tests, Default=ON." ON) -option(CORO_CODE_COVERAGE "Enable code coverage, tests must also be enabled, Default=OFF" OFF) +# Set the githooks directory to auto format and update the readme. +message("git config core.hooksPath .githooks") +execute_process( + COMMAND git config core.hooksPath .githooks + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) -message("${PROJECT_NAME} CORO_BUILD_TESTS = ${CORO_BUILD_TESTS}") -message("${PROJECT_NAME} CORO_CODE_COVERAGE = ${CORO_CODE_COVERAGE}") +option(LIBCORO_BUILD_TESTS "Build the tests, Default=ON." ON) +option(LIBCORO_CODE_COVERAGE "Enable code coverage, tests must also be enabled, Default=OFF" OFF) +option(LIBCORO_BUILD_EXAMPLES "Build the examples, Default=ON." ON) + +message("${PROJECT_NAME} LIBCORO_BUILD_TESTS = ${LIBCORO_BUILD_TESTS}") +message("${PROJECT_NAME} LIBCORO_CODE_COVERAGE = ${LIBCORO_CODE_COVERAGE}") +message("${PROJECT_NAME} LIBCORO_BUILD_EXAMPLES = ${LIBCORO_BUILD_EXAMPLES}") set(CARES_STATIC ON CACHE INTERNAL "") set(CARES_SHARED OFF CACHE INTERNAL "") @@ -36,6 +45,7 @@ set(LIBCORO_SOURCE_FILES inc/coro/generator.hpp inc/coro/io_scheduler.hpp src/io_scheduler.cpp inc/coro/latch.hpp + inc/coro/mutex.hpp src/mutex.cpp inc/coro/poll.hpp inc/coro/shutdown.hpp inc/coro/sync_wait.hpp src/sync_wait.cpp @@ -60,13 +70,16 @@ elseif(${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") message(FATAL_ERROR "Clang is currently not supported.") endif() - -if(CORO_BUILD_TESTS) - if(CORO_CODE_COVERAGE) +if(LIBCORO_BUILD_TESTS) + if(LIBCORO_CODE_COVERAGE) target_compile_options(${PROJECT_NAME} PRIVATE --coverage) target_link_libraries(${PROJECT_NAME} PRIVATE gcov) endif() enable_testing() add_subdirectory(test) +endif() + +if(LIBCORO_BUILD_EXAMPLES) + add_subdirectory(examples) endif() \ No newline at end of file diff --git a/Makefile b/Makefile index a22774c..34b3f00 100644 --- a/Makefile +++ b/Makefile @@ -1,42 +1,4 @@ -.DEFAULT_GOAL := debug - -# Builds the project and tests in the Debug directory. -debug: - @$(MAKE) compile BUILD_TYPE=Debug --no-print-directory - -# Builds the project and tests in the RelWithDebInfo directory. -release-with-debug-info: - @$(MAKE) compile BUILD_TYPE=RelWithDebInfo --no-print-directory - -# Builds the project and tests in the Release directory. -release: - @$(MAKE) compile BUILD_TYPE=Release --no-print-directory - # Internal target for all build targets to call. -compile: - mkdir -p ${BUILD_TYPE}; \ - cd ${BUILD_TYPE}; \ - cmake -DCMAKE_BUILD_TYPE=${BUILD_TYPE} ..; \ - cmake --build . -- -j $(nproc) - -# Run Debug tests. -debug-test: - @$(MAKE) test BUILD_TYPE=Debug --no-print-directory - -# Run RelWithDebInfo tests. -release-with-debug-info-test: - @$(MAKE) test BUILD_TYPE=RelWithDebInfo --no-print-directory - -# Run Release tests. -release-test: - @$(MAKE) test BUILD_TYPE=Release --no-print-directory - -# Internal target for all test targets to call. -.PHONY: test -test: - cd ${BUILD_TYPE}; \ - ctest -VV - # Cleans all build types. .PHONY: clean clean: @@ -45,6 +7,7 @@ clean: rm -rf Release # Runs clang-format with the project's .clang-format. +.PHONY: format format: # Inlcude *.hpp|*.h|*.cpp but ignore catch lib as well as RelWithDebInfo|Release|Debug|build find . \( -name '*.hpp' -or -name '*.h' -or -name '*.cpp' \) \ @@ -53,4 +16,5 @@ format: -and -not -iwholename '*/Release/*' \ -and -not -iwholename '*/Debug/*' \ -and -not -iwholename '*/build/*' \ + -and -not -iwholename '*/vendor/*' \ -exec clang-format -i --style=file {} \; diff --git a/README.md b/README.md index 44221a1..117bc46 100644 --- a/README.md +++ b/README.md @@ -1,38 +1,176 @@ -libcoro C++20 Coroutines -======================== +# libcoro C++20 linux coroutine library [![CI](https://github.com/jbaldwin/libcoro/workflows/build/badge.svg)](https://github.com/jbaldwin/libcoro/workflows/build/badge.svg) [![Coverage Status](https://coveralls.io/repos/github/jbaldwin/libcoro/badge.svg?branch=master)](https://coveralls.io/github/jbaldwin/libcoro?branch=master) [![language][badge.language]][language] [![license][badge.license]][license] +**libcoro** is licensed under the Apache 2.0 license. + +**libcoro** is meant to provide low level coroutine constructs for building larger applications, the current focus is around high performance networking coroutine support. + +## Overview + * C++20 coroutines! + * Modern Safe C++20 API + * Higher level coroutine constructs + ** coro::task + ** coro::generator + ** coro::event + ** coro::latch + ** coro::mutex + ** coro::sync_wait(awaitable) + *** coro::when_all(awaitable...) + * Schedulers + ** coro::thread_pool for coroutine cooperative multitasking + ** coro::io_scheduler for driving i/o events, uses thread_pool + *** epoll driver implemented + *** io_uring driver planned (will be required for file i/o) + * Coroutine Networking + ** coro::net::dns_resolver for async dns, leverages libc-ares + ** coro::tcp_client and coro::tcp_server + ** coro::udp_peer + +### A note on co_await +Its important to note with coroutines that depending on the construct used _any_ `co_await` has the +potential to switch the thread that is executing the currently running coroutine. In general this shouldn't +affect the way any user of the library would write code except for `thread_local`. Usage of `thread_local` +should be extremely careful and _never_ used across any `co_await` boundary do to thread switching and +work stealing on thread pools. + +### coro::event +The `coro::event` is a thread safe async tool to have 1 or more waiters suspend for an event to be set +before proceeding. The implementation of event currently will resume execution of all waiters on the +thread that sets the event. If the event is already set when a waiter goes to wait on the thread they +will simply continue executing with no suspend or wait time incurred. + +```C++ +#include +#include + +int main() +{ + coro::event e; + + // This task will wait until the given event has been set before advancings + auto make_wait_task = [](const coro::event& e, int i) -> coro::task { + std::cout << "task " << i << " is waiting on the event...\n"; + co_await e; + std::cout << "task " << i << " event triggered, now resuming.\n"; + co_return i; + }; + + // This task will trigger the event allowing all waiting tasks to proceed. + auto make_set_task = [](coro::event& e) -> coro::task { + std::cout << "set task is triggering the event\n"; + e.set(); + co_return; + }; + + // Synchronously wait until all the tasks are completed, this is intentionally + // starting the first 3 wait tasks prior to the final set task. + coro::sync_wait( + coro::when_all_awaitable(make_wait_task(e, 1), make_wait_task(e, 2), make_wait_task(e, 3), make_set_task(e))); +} +``` + +Expected output: +```bash +$ ./Debug/examples/coro_event +task 1 is waiting on the event... +task 2 is waiting on the event... +task 3 is waiting on the event... +set task is triggering the event +task 3 event triggered, now resuming. +task 2 event triggered, now resuming. +task 1 event triggered, now resuming. +``` + +## Usage + +### Requirements + C++20 Compiler with coroutine support + g++10.2 is tested + CMake + make or ninja + pthreads + gcov/lcov (For generating coverage only) + +### Instructions + +#### Cloning the project +This project uses gitsubmodules, to properly checkout this project use: + + git clone --recurse-submodules + +This project depends on the following projects: + * [libc-ares](https://github.com/c-ares/c-ares) For async DNS resolver. + +#### Building + mkdir Release && cd Release + cmake -DCMAKE_BUILD_TYPE=Release .. + cmake --build . + +CMake Options: + +| Name | Default | Description | +|:-----------------------|:--------|:--------------------------------------------------------------| +| LIBCORO_BUILD_TESTS | ON | Should the tests be built? | +| LIBCORO_CODE_COVERAGE | OFF | Should code coverage be enabled? Requires tests to be enabled | +| LIBCORO_BUILD_EXAMPLES | ON | Should the examples be built? | + +#### Adding to your project + +##### add_subdirectory() + +```cmake +# Include the checked out libcoro code in your CMakeLists.txt file +add_subdirectory(path/to/libcoro) + +# Link the libcoro cmake target to your project(s). +target_link_libraries(${PROJECT_NAME} PUBLIC libcoro) + +``` + +##### FetchContent +CMake can include the project directly by downloading the source, compiling and linking to your project via FetchContent, below is an example on how you might do this within your project. + + +```cmake +cmake_minimum_required(VERSION 3.11) + +# Fetch the project and make it available for use. +include(FetchContent) +FetchContent_Declare( + libcoro + GIT_REPOSITORY https://github.com/jbaldwin/libcoro.git + GIT_TAG +) +FetchContent_MakeAvailable(libcoro) + +# Link the libcoro cmake target to your project(s). +target_link_libraries(${PROJECT_NAME} PUBLIC libcoro) + +``` + +#### Tests +The tests will automatically be run by github actions on creating a pull request. They can also be ran locally: + + # Invoke via cmake: + ctest -VV + + # Or invoke directly, can pass the name of tests to execute, the framework used is catch2 + # catch2 supports '*' wildcards to run multiple tests or comma delimited ',' test names. + # The below will run all tests with "tcp_server" prefix in their test name. + ./Debug/test/libcoro_test "tcp_server*" + +### Support + +File bug reports, feature requests and questions using [GitHub libcoro Issues](https://github.com/jbaldwin/libcoro/issues) + +Copyright © 2020-2021 Josh Baldwin + [badge.language]: https://img.shields.io/badge/language-C%2B%2B20-yellow.svg [badge.license]: https://img.shields.io/badge/license-Apache--2.0-blue [language]: https://en.wikipedia.org/wiki/C%2B%2B17 [license]: https://en.wikipedia.org/wiki/Apache_License - -**libcoro** is licensed under the Apache 2.0 license. - -# Background -Libcoro is a C++20 coroutine library. So far most inspiration has been gleaned from [libcppcoro](https://github.com/lewissbaker/cppcoro) an amazing C++ coroutine library as well as Lewis Baker's great coroutine blog entries [https://lewissbaker.github.io/](Blog). I would highly recommend anyone who is trying to learn the internals of C++20's coroutine implementation to read all of his blog entries, they are extremely insightful and well written. - -# Goal -Libcoro is currently more of a learning experience for myself but ultimately I'd like to turn this into a great linux coroutine base library with an easy to use HTTP scheduler/server. - -# Building -There is a root makefile with various commands to help make building and running tests on this project easier. - -```bash -# Build targets -make debug|release-with-debug-info|release - -# Run tests targets -make debug-test|release-with-debug-info-tests|release-tests - -# Clean all builds. -make clean - -# clang-format the code -make format -``` diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt new file mode 100644 index 0000000..95b2071 --- /dev/null +++ b/examples/CMakeLists.txt @@ -0,0 +1,14 @@ +cmake_minimum_required(VERSION 3.16) +project(libcoro_examples) + +add_executable(coro_event coro_event.cpp) +target_compile_features(coro_event PUBLIC cxx_std_20) +target_link_libraries(coro_event PUBLIC libcoro) + +if(${CMAKE_CXX_COMPILER_ID} MATCHES "GNU") + target_compile_options(coro_event PUBLIC -fcoroutines -Wall -Wextra -pipe) +elseif(${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") + message(FATAL_ERROR "Clang is currently not supported.") +else() + message(FATAL_ERROR "Unsupported compiler.") +endif() \ No newline at end of file diff --git a/examples/coro_event.cpp b/examples/coro_event.cpp new file mode 100644 index 0000000..c47f4c2 --- /dev/null +++ b/examples/coro_event.cpp @@ -0,0 +1,27 @@ +#include +#include + +int main() +{ + coro::event e; + + // This task will wait until the given event has been set before advancings + auto make_wait_task = [](const coro::event& e, int i) -> coro::task { + std::cout << "task " << i << " is waiting on the event...\n"; + co_await e; + std::cout << "task " << i << " event triggered, now resuming.\n"; + co_return i; + }; + + // This task will trigger the event allowing all waiting tasks to proceed. + auto make_set_task = [](coro::event& e) -> coro::task { + std::cout << "set task is triggering the event\n"; + e.set(); + co_return; + }; + + // Synchronously wait until all the tasks are completed, this is intentionally + // starting the first 3 wait tasks prior to the final set task. + coro::sync_wait( + coro::when_all_awaitable(make_wait_task(e, 1), make_wait_task(e, 2), make_wait_task(e, 3), make_set_task(e))); +} diff --git a/inc/coro/concepts/buffer.hpp b/inc/coro/concepts/buffer.hpp index e1c90a7..07ea95c 100644 --- a/inc/coro/concepts/buffer.hpp +++ b/inc/coro/concepts/buffer.hpp @@ -1,12 +1,11 @@ #pragma once #include -#include #include +#include namespace coro::concepts { - // clang-format off template concept const_buffer = requires(const type t) diff --git a/inc/coro/concepts/promise.hpp b/inc/coro/concepts/promise.hpp index c3e714a..fa6e7e0 100644 --- a/inc/coro/concepts/promise.hpp +++ b/inc/coro/concepts/promise.hpp @@ -6,7 +6,6 @@ namespace coro::concepts { - // clang-format off template concept promise = requires(type t) diff --git a/inc/coro/coro.hpp b/inc/coro/coro.hpp index e84bafa..a557615 100644 --- a/inc/coro/coro.hpp +++ b/inc/coro/coro.hpp @@ -19,6 +19,7 @@ #include "coro/generator.hpp" #include "coro/io_scheduler.hpp" #include "coro/latch.hpp" +#include "coro/mutex.hpp" #include "coro/sync_wait.hpp" #include "coro/task.hpp" #include "coro/thread_pool.hpp" diff --git a/inc/coro/io_scheduler.hpp b/inc/coro/io_scheduler.hpp index 22c5fe5..a084f1a 100644 --- a/inc/coro/io_scheduler.hpp +++ b/inc/coro/io_scheduler.hpp @@ -1,9 +1,9 @@ #pragma once #include "coro/concepts/awaitable.hpp" +#include "coro/net/socket.hpp" #include "coro/poll.hpp" #include "coro/shutdown.hpp" -#include "coro/net/socket.hpp" #include "coro/task.hpp" #include @@ -54,7 +54,9 @@ public: auto await_ready() const noexcept -> bool { return m_token.is_set(); } auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool; - auto await_resume() noexcept { /* no-op */ } + auto await_resume() noexcept + { /* no-op */ + } const resume_token_base& m_token; std::coroutine_handle<> m_awaiting_coroutine; @@ -350,10 +352,7 @@ public: auto poll(fd_t fd, poll_op op, std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task; - auto poll( - const net::socket& sock, - poll_op op, - std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) + auto poll(const net::socket& sock, poll_op op, std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task; /** @@ -370,9 +369,10 @@ public: -> coro::task>; auto read( - const net::socket& sock, + const net::socket& sock, std::span buffer, - std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task>; + std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) + -> coro::task>; /** * This function will first poll the given `fd` to make sure it can be written to. Once notified @@ -388,9 +388,10 @@ public: -> coro::task>; auto write( - const net::socket& sock, + const net::socket& sock, const std::span buffer, - std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task>; + std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) + -> coro::task>; /** * Immediately yields the current task and places it at the end of the queue of tasks waiting @@ -580,10 +581,10 @@ private: auto resume(std::coroutine_handle<> handle) -> void; - static const constexpr std::chrono::milliseconds m_default_timeout{1000}; - static const constexpr std::chrono::milliseconds m_no_timeout{0}; - static const constexpr std::size_t m_max_events = 8; - std::array m_events{}; + static const constexpr std::chrono::milliseconds m_default_timeout{1000}; + static const constexpr std::chrono::milliseconds m_no_timeout{0}; + static const constexpr std::size_t m_max_events = 8; + std::array m_events{}; auto process_task_and_start(task& task) -> void; auto process_task_variant(task_variant& tv) -> void; diff --git a/inc/coro/mutex.hpp b/inc/coro/mutex.hpp new file mode 100644 index 0000000..50c7b33 --- /dev/null +++ b/inc/coro/mutex.hpp @@ -0,0 +1,55 @@ +#pragma once + +#include +#include +#include +#include + +namespace coro +{ +class mutex +{ +public: + struct scoped_lock + { + friend class mutex; + + scoped_lock(mutex& m) : m_mutex(m) {} + ~scoped_lock() { m_mutex.unlock(); } + + mutex& m_mutex; + }; + + struct awaiter + { + awaiter(mutex& m) noexcept : m_mutex(m) {} + + auto await_ready() const noexcept -> bool; + auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool; + auto await_resume() noexcept -> scoped_lock; + + mutex& m_mutex; + std::coroutine_handle<> m_awaiting_coroutine; + }; + + explicit mutex() noexcept = default; + ~mutex() = default; + + mutex(const mutex&) = delete; + mutex(mutex&&) = delete; + auto operator=(const mutex&) -> mutex& = delete; + auto operator=(mutex&&) -> mutex& = delete; + + auto lock() -> awaiter; + auto try_lock() -> bool; + auto unlock() -> void; + +private: + friend class scoped_lock; + + std::atomic m_state{false}; + std::mutex m_waiter_mutex{}; + std::deque m_waiter_list{}; +}; + +} // namespace coro diff --git a/inc/coro/net/dns_resolver.hpp b/inc/coro/net/dns_resolver.hpp index b5908ba..97a8acd 100644 --- a/inc/coro/net/dns_resolver.hpp +++ b/inc/coro/net/dns_resolver.hpp @@ -1,24 +1,23 @@ #pragma once #include "coro/io_scheduler.hpp" -#include "coro/net/ip_address.hpp" #include "coro/net/hostname.hpp" +#include "coro/net/ip_address.hpp" #include "coro/task.hpp" #include -#include -#include -#include #include -#include #include -#include +#include +#include +#include #include +#include +#include namespace coro::net { - class dns_resolver; enum class dns_status @@ -30,6 +29,7 @@ enum class dns_status class dns_result { friend dns_resolver; + public: explicit dns_result(coro::resume_token& token, uint64_t pending_dns_requests); ~dns_result() = default; @@ -44,18 +44,14 @@ public: * were resolved from the hostname. */ auto ip_addresses() const -> const std::vector& { return m_ip_addresses; } + private: - coro::resume_token& m_token; - uint64_t m_pending_dns_requests{0}; - dns_status m_status{dns_status::complete}; + coro::resume_token& m_token; + uint64_t m_pending_dns_requests{0}; + dns_status m_status{dns_status::complete}; std::vector m_ip_addresses{}; - friend auto ares_dns_callback( - void* arg, - int status, - int timeouts, - struct hostent* host - ) -> void; + friend auto ares_dns_callback(void* arg, int status, int timeouts, struct hostent* host) -> void; }; class dns_resolver @@ -63,7 +59,7 @@ class dns_resolver public: explicit dns_resolver(io_scheduler& scheduler, std::chrono::milliseconds timeout); dns_resolver(const dns_resolver&) = delete; - dns_resolver(dns_resolver&&) = delete; + dns_resolver(dns_resolver&&) = delete; auto operator=(const dns_resolver&) noexcept -> dns_resolver& = delete; auto operator=(dns_resolver&&) noexcept -> dns_resolver& = delete; ~dns_resolver(); @@ -72,6 +68,7 @@ public: * @param hn The hostname to resolve its ip addresses. */ auto host_by_name(const net::hostname& hn) -> coro::task>; + private: /// The io scheduler to drive the events for dns lookups. io_scheduler& m_scheduler; diff --git a/inc/coro/net/hostname.hpp b/inc/coro/net/hostname.hpp index 8b74048..3de016a 100644 --- a/inc/coro/net/hostname.hpp +++ b/inc/coro/net/hostname.hpp @@ -4,25 +4,21 @@ namespace coro::net { - class hostname { public: hostname() = default; - explicit hostname(std::string hn) - : m_hostname(std::move(hn)) {} + explicit hostname(std::string hn) : m_hostname(std::move(hn)) {} hostname(const hostname&) = default; - hostname(hostname&&) = default; + hostname(hostname&&) = default; auto operator=(const hostname&) noexcept -> hostname& = default; auto operator=(hostname&&) noexcept -> hostname& = default; - ~hostname() = default; + ~hostname() = default; auto data() const -> const std::string& { return m_hostname; } - auto operator<=>(const hostname& other) const - { - return m_hostname <=> other.m_hostname; - } + auto operator<=>(const hostname& other) const { return m_hostname <=> other.m_hostname; } + private: std::string m_hostname; }; diff --git a/inc/coro/net/ip_address.hpp b/inc/coro/net/ip_address.hpp index 5ee5dbe..531bc16 100644 --- a/inc/coro/net/ip_address.hpp +++ b/inc/coro/net/ip_address.hpp @@ -1,15 +1,14 @@ #pragma once -#include -#include -#include -#include #include +#include #include +#include +#include +#include namespace coro::net { - enum class domain_t : int { ipv4 = AF_INET, @@ -25,14 +24,13 @@ public: static const constexpr size_t ipv6_len{16}; ip_address() = default; - ip_address(std::span binary_address, domain_t domain = domain_t::ipv4) - : m_domain(domain) + ip_address(std::span binary_address, domain_t domain = domain_t::ipv4) : m_domain(domain) { - if(m_domain == domain_t::ipv4 && binary_address.size() > ipv4_len) + if (m_domain == domain_t::ipv4 && binary_address.size() > ipv4_len) { throw std::runtime_error{"coro::net::ip_address provided binary ip address is too long"}; } - else if(binary_address.size() > ipv6_len) + else if (binary_address.size() > ipv6_len) { throw std::runtime_error{"coro::net::ip_address provided binary ip address is too long"}; } @@ -40,15 +38,15 @@ public: std::copy(binary_address.begin(), binary_address.end(), m_data.begin()); } ip_address(const ip_address&) = default; - ip_address(ip_address&&) = default; + ip_address(ip_address&&) = default; auto operator=(const ip_address&) noexcept -> ip_address& = default; auto operator=(ip_address&&) noexcept -> ip_address& = default; - ~ip_address() = default; + ~ip_address() = default; auto domain() const -> domain_t { return m_domain; } auto data() const -> std::span { - if(m_domain == domain_t::ipv4) + if (m_domain == domain_t::ipv4) { return std::span{m_data.data(), ipv4_len}; } @@ -64,7 +62,7 @@ public: addr.m_domain = domain; auto success = inet_pton(static_cast(addr.m_domain), address.data(), addr.m_data.data()); - if(success != 1) + if (success != 1) { throw std::runtime_error{"coro::net::ip_address faild to convert from string"}; } @@ -75,7 +73,7 @@ public: auto to_string() const -> std::string { std::string output; - if(m_domain == domain_t::ipv4) + if (m_domain == domain_t::ipv4) { output.resize(INET_ADDRSTRLEN, '\0'); } @@ -85,7 +83,7 @@ public: } auto success = inet_ntop(static_cast(m_domain), m_data.data(), output.data(), output.length()); - if(success != nullptr) + if (success != nullptr) { auto len = strnlen(success, output.length()); output.resize(len); @@ -101,7 +99,7 @@ public: auto operator<=>(const ip_address& other) const = default; private: - domain_t m_domain{domain_t::ipv4}; + domain_t m_domain{domain_t::ipv4}; std::array m_data{}; }; diff --git a/inc/coro/net/recv_status.hpp b/inc/coro/net/recv_status.hpp index b04a832..12ed9f3 100644 --- a/inc/coro/net/recv_status.hpp +++ b/inc/coro/net/recv_status.hpp @@ -1,29 +1,28 @@ #pragma once -#include #include +#include #include namespace coro::net { - enum class recv_status : int64_t { ok = 0, /// The peer closed the socket. closed = -1, /// The udp socket has not been bind()'ed to a local port. - udp_not_bound = -2, - try_again = EAGAIN, - would_block = EWOULDBLOCK, + udp_not_bound = -2, + try_again = EAGAIN, + would_block = EWOULDBLOCK, bad_file_descriptor = EBADF, - connection_refused = ECONNREFUSED, - memory_fault = EFAULT, - interrupted = EINTR, - invalid_argument = EINVAL, - no_memory = ENOMEM, - not_connected = ENOTCONN, - not_a_socket = ENOTSOCK + connection_refused = ECONNREFUSED, + memory_fault = EFAULT, + interrupted = EINTR, + invalid_argument = EINVAL, + no_memory = ENOMEM, + not_connected = ENOTCONN, + not_a_socket = ENOTSOCK }; auto to_string(recv_status status) -> const std::string&; diff --git a/inc/coro/net/send_status.hpp b/inc/coro/net/send_status.hpp index ae77691..67b672c 100644 --- a/inc/coro/net/send_status.hpp +++ b/inc/coro/net/send_status.hpp @@ -1,31 +1,30 @@ #pragma once -#include #include +#include namespace coro::net { - enum class send_status : int64_t { - ok = 0, - permission_denied = EACCES, - try_again = EAGAIN, - would_block = EWOULDBLOCK, - already_in_progress = EALREADY, - bad_file_descriptor = EBADF, - connection_reset = ECONNRESET, - no_peer_address = EDESTADDRREQ, - memory_fault = EFAULT, - interrupted = EINTR, - is_connection = EISCONN, - message_size = EMSGSIZE, - output_queue_full = ENOBUFS, - no_memory = ENOMEM, - not_connected = ENOTCONN, - not_a_socket = ENOTSOCK, + ok = 0, + permission_denied = EACCES, + try_again = EAGAIN, + would_block = EWOULDBLOCK, + already_in_progress = EALREADY, + bad_file_descriptor = EBADF, + connection_reset = ECONNRESET, + no_peer_address = EDESTADDRREQ, + memory_fault = EFAULT, + interrupted = EINTR, + is_connection = EISCONN, + message_size = EMSGSIZE, + output_queue_full = ENOBUFS, + no_memory = ENOMEM, + not_connected = ENOTCONN, + not_a_socket = ENOTSOCK, operationg_not_supported = EOPNOTSUPP, - pipe_closed = EPIPE + pipe_closed = EPIPE }; } // namespace coro::net diff --git a/inc/coro/net/socket.hpp b/inc/coro/net/socket.hpp index 0a615e9..570a301 100644 --- a/inc/coro/net/socket.hpp +++ b/inc/coro/net/socket.hpp @@ -13,7 +13,6 @@ namespace coro::net { - class socket { public: @@ -36,9 +35,9 @@ public: struct options { /// The domain for the socket. - domain_t domain; + domain_t domain; /// The type of socket. - type_t type; + type_t type; /// If the socket should be blocking or non-blocking. blocking_t blocking; }; @@ -51,7 +50,7 @@ public: socket(const socket&) = delete; socket(socket&& other) : m_fd(std::exchange(other.m_fd, -1)) {} auto operator=(const socket&) -> socket& = delete; - auto operator=(socket&& other) noexcept -> socket&; + auto operator =(socket&& other) noexcept -> socket&; ~socket() { close(); } @@ -105,9 +104,6 @@ auto make_socket(const socket::options& opts) -> socket; * for udp types. */ auto make_accept_socket( - const socket::options& opts, - const net::ip_address& address, - uint16_t port, - int32_t backlog = 128) -> socket; + const socket::options& opts, const net::ip_address& address, uint16_t port, int32_t backlog = 128) -> socket; } // namespace coro::net diff --git a/inc/coro/net/tcp_client.hpp b/inc/coro/net/tcp_client.hpp index 58367f9..ca35460 100644 --- a/inc/coro/net/tcp_client.hpp +++ b/inc/coro/net/tcp_client.hpp @@ -1,21 +1,20 @@ #pragma once #include "coro/concepts/buffer.hpp" +#include "coro/io_scheduler.hpp" +#include "coro/net/connect.hpp" #include "coro/net/ip_address.hpp" #include "coro/net/recv_status.hpp" #include "coro/net/send_status.hpp" #include "coro/net/socket.hpp" -#include "coro/net/connect.hpp" #include "coro/poll.hpp" #include "coro/task.hpp" -#include "coro/io_scheduler.hpp" #include -#include -#include #include -#include +#include #include +#include #include namespace coro @@ -25,7 +24,6 @@ class io_scheduler; namespace coro::net { - class tcp_server; class tcp_client @@ -48,9 +46,7 @@ public: */ tcp_client( io_scheduler& scheduler, - options opts = options{ - .address = {net::ip_address::from_string("127.0.0.1")}, - .port = 8080}); + options opts = options{.address = {net::ip_address::from_string("127.0.0.1")}, .port = 8080}); tcp_client(const tcp_client&) = delete; tcp_client(tcp_client&&) = default; auto operator=(const tcp_client&) noexcept -> tcp_client& = delete; @@ -71,8 +67,7 @@ public: * @param timeout How long to wait for the connection to establish? Timeout of zero is indefinite. * @return The result status of trying to connect. */ - auto connect( - std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task; + auto connect(std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task; /** * Polls for the given operation on this client's tcp socket. This should be done prior to @@ -82,9 +77,8 @@ public: * @return The status result of th poll operation. When poll_status::event is returned then the * event operation is ready. */ - auto poll( - coro::poll_op op, - std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task + auto poll(coro::poll_op op, std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) + -> coro::task { co_return co_await m_io_scheduler.poll(m_socket, op, timeout); } @@ -97,17 +91,16 @@ public: * bytes will be a subspan or full span of the given input buffer. */ template - auto recv( - buffer_type&& buffer) -> std::pair> + auto recv(buffer_type&& buffer) -> std::pair> { // If the user requested zero bytes, just return. - if(buffer.empty()) + if (buffer.empty()) { return {recv_status::ok, std::span{}}; } auto bytes_recv = ::recv(m_socket.native_handle(), buffer.data(), buffer.size(), 0); - if(bytes_recv > 0) + if (bytes_recv > 0) { // Ok, we've recieved some data. return {recv_status::ok, std::span{buffer.data(), static_cast(bytes_recv)}}; @@ -134,17 +127,16 @@ public: * were successfully sent the status will be 'ok' and the remaining span will be empty. */ template - auto send( - const buffer_type& buffer) -> std::pair> + auto send(const buffer_type& buffer) -> std::pair> { // If the user requested zero bytes, just return. - if(buffer.empty()) + if (buffer.empty()) { return {send_status::ok, std::span{buffer.data(), buffer.size()}}; } auto bytes_sent = ::send(m_socket.native_handle(), buffer.data(), buffer.size(), 0); - if(bytes_sent >= 0) + if (bytes_sent >= 0) { // Some or all of the bytes were written. return {send_status::ok, std::span{buffer.data() + bytes_sent, buffer.size() - bytes_sent}}; @@ -159,10 +151,7 @@ public: private: /// The tcp_server creates already connected clients and provides a tcp socket pre-built. friend tcp_server; - tcp_client( - io_scheduler& scheduler, - net::socket socket, - options opts); + tcp_client(io_scheduler& scheduler, net::socket socket, options opts); /// The scheduler that will drive this tcp client. io_scheduler& m_io_scheduler; diff --git a/inc/coro/net/tcp_server.hpp b/inc/coro/net/tcp_server.hpp index 6a18f5d..7147590 100644 --- a/inc/coro/net/tcp_server.hpp +++ b/inc/coro/net/tcp_server.hpp @@ -1,9 +1,9 @@ #pragma once -#include "coro/net/ip_address.hpp" -#include "coro/net/tcp_client.hpp" #include "coro/io_scheduler.hpp" +#include "coro/net/ip_address.hpp" #include "coro/net/socket.hpp" +#include "coro/net/tcp_client.hpp" #include "coro/task.hpp" #include @@ -18,26 +18,22 @@ public: struct options { /// The ip address for the tcp server to bind and listen on. - net::ip_address address{net::ip_address::from_string("0.0.0.0")}; + net::ip_address address{net::ip_address::from_string("0.0.0.0")}; /// The port for the tcp server to bind and listen on. - uint16_t port{8080}; + uint16_t port{8080}; /// The kernel backlog of connections to buffer. - int32_t backlog{128}; + int32_t backlog{128}; }; tcp_server( io_scheduler& scheduler, - options opts = - options{ - .address = net::ip_address::from_string("0.0.0.0"), - .port = 8080, - .backlog = 128}); + options opts = options{.address = net::ip_address::from_string("0.0.0.0"), .port = 8080, .backlog = 128}); tcp_server(const tcp_server&) = delete; tcp_server(tcp_server&&) = delete; auto operator=(const tcp_server&) -> tcp_server& = delete; auto operator=(tcp_server&&) -> tcp_server& = delete; - ~tcp_server() = default; + ~tcp_server() = default; /** * Polls for new incoming tcp connections. diff --git a/inc/coro/net/udp_peer.hpp b/inc/coro/net/udp_peer.hpp index 74f3d85..c70fa27 100644 --- a/inc/coro/net/udp_peer.hpp +++ b/inc/coro/net/udp_peer.hpp @@ -1,16 +1,16 @@ #pragma once #include "coro/concepts/buffer.hpp" -#include "coro/net/ip_address.hpp" -#include "coro/net/socket.hpp" -#include "coro/net/send_status.hpp" -#include "coro/net/recv_status.hpp" -#include "coro/task.hpp" #include "coro/io_scheduler.hpp" +#include "coro/net/ip_address.hpp" +#include "coro/net/recv_status.hpp" +#include "coro/net/send_status.hpp" +#include "coro/net/socket.hpp" +#include "coro/task.hpp" #include -#include #include +#include namespace coro { @@ -19,7 +19,6 @@ class io_scheduler; namespace coro::net { - class udp_peer { public: @@ -37,22 +36,18 @@ public: * Creates a udp peer that can send packets but not receive them. This udp peer will not explicitly * bind to a local ip+port. */ - explicit udp_peer( - io_scheduler& scheduler, - net::domain_t domain = net::domain_t::ipv4); + explicit udp_peer(io_scheduler& scheduler, net::domain_t domain = net::domain_t::ipv4); /** * Creates a udp peer that can send and receive packets. This peer will bind to the given ip_port. */ - explicit udp_peer( - io_scheduler& scheduler, - const info& bind_info); + explicit udp_peer(io_scheduler& scheduler, const info& bind_info); udp_peer(const udp_peer&) = delete; udp_peer(udp_peer&&) = default; auto operator=(const udp_peer&) noexcept -> udp_peer& = delete; auto operator=(udp_peer&&) noexcept -> udp_peer& = default; - ~udp_peer() = default; + ~udp_peer() = default; /** * @param op The poll operation to perform on the udp socket. Note that if this is a send only @@ -60,9 +55,8 @@ public: * @param timeout The timeout for the poll operation to be ready. * @return The result status of the poll operation. */ - auto poll( - poll_op op, - std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task + auto poll(poll_op op, std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) + -> coro::task { co_return co_await m_io_scheduler.poll(m_socket, op, timeout); } @@ -74,11 +68,9 @@ public: * un-sent will correspond to bytes at the end of the given buffer. */ template - auto sendto( - const info& peer_info, - const buffer_type& buffer) -> std::pair> + auto sendto(const info& peer_info, const buffer_type& buffer) -> std::pair> { - if(buffer.empty()) + if (buffer.empty()) { return {send_status::ok, std::span{}}; } @@ -91,14 +83,9 @@ public: socklen_t peer_len{sizeof(peer)}; auto bytes_sent = ::sendto( - m_socket.native_handle(), - buffer.data(), - buffer.size(), - 0, - reinterpret_cast(&peer), - peer_len); + m_socket.native_handle(), buffer.data(), buffer.size(), 0, reinterpret_cast(&peer), peer_len); - if(bytes_sent >= 0) + if (bytes_sent >= 0) { return {send_status::ok, std::span{buffer.data() + bytes_sent, buffer.size() - bytes_sent}}; } @@ -116,27 +103,21 @@ public: * it might not fill the entire buffer. */ template - auto recvfrom( - buffer_type&& buffer) -> std::tuple> + auto recvfrom(buffer_type&& buffer) -> std::tuple> { // The user must bind locally to be able to receive packets. - if(!m_bound) + if (!m_bound) { return {recv_status::udp_not_bound, udp_peer::info{}, std::span{}}; } sockaddr_in peer{}; - socklen_t peer_len{sizeof(peer)}; + socklen_t peer_len{sizeof(peer)}; auto bytes_read = ::recvfrom( - m_socket.native_handle(), - buffer.data(), - buffer.size(), - 0, - reinterpret_cast(&peer), - &peer_len); + m_socket.native_handle(), buffer.data(), buffer.size(), 0, reinterpret_cast(&peer), &peer_len); - if(bytes_read < 0) + if (bytes_read < 0) { return {static_cast(errno), udp_peer::info{}, std::span{}}; } @@ -150,10 +131,8 @@ public: recv_status::ok, udp_peer::info{ .address = net::ip_address{ip_addr_view, static_cast(peer.sin_family)}, - .port = ntohs(peer.sin_port) - }, - std::span{buffer.data(), static_cast(bytes_read)} - }; + .port = ntohs(peer.sin_port)}, + std::span{buffer.data(), static_cast(bytes_read)}}; } private: diff --git a/inc/coro/sync_wait.hpp b/inc/coro/sync_wait.hpp index cf9b61f..46dfbeb 100644 --- a/inc/coro/sync_wait.hpp +++ b/inc/coro/sync_wait.hpp @@ -182,7 +182,9 @@ private: coroutine_type m_coroutine; }; -template::awaiter_return_type> +template< + concepts::awaitable awaitable, + typename return_type = concepts::awaitable_traits::awaiter_return_type> static auto make_sync_wait_task(awaitable&& a) -> sync_wait_task { if constexpr (std::is_void_v) diff --git a/inc/coro/task.hpp b/inc/coro/task.hpp index 75bc53a..a85e751 100644 --- a/inc/coro/task.hpp +++ b/inc/coro/task.hpp @@ -198,7 +198,7 @@ public: return false; } - auto operator co_await() const & noexcept + auto operator co_await() const& noexcept { struct awaitable : public awaitable_base { @@ -208,7 +208,7 @@ public: return awaitable{m_coroutine}; } - auto operator co_await() const && noexcept + auto operator co_await() const&& noexcept { struct awaitable : public awaitable_base { diff --git a/inc/coro/when_all.hpp b/inc/coro/when_all.hpp index 96d5f2e..5dde715 100644 --- a/inc/coro/when_all.hpp +++ b/inc/coro/when_all.hpp @@ -434,7 +434,9 @@ private: coroutine_handle_type m_coroutine; }; -template::awaiter_return_type> +template< + concepts::awaitable awaitable, + typename return_type = concepts::awaitable_traits::awaiter_return_type> static auto make_when_all_task(awaitable&& a) -> when_all_task { if constexpr (std::is_void_v) @@ -453,12 +455,14 @@ static auto make_when_all_task(awaitable&& a) -> when_all_task template [[nodiscard]] auto when_all_awaitable(awaitables_type&&... awaitables) { - return detail::when_all_ready_awaitable< - std::tuple::awaiter_return_type>...>>( + return detail::when_all_ready_awaitable::awaiter_return_type>...>>( std::make_tuple(detail::make_when_all_task(std::forward(awaitables))...)); } -template::awaiter_return_type> +template< + concepts::awaitable awaitable, + typename return_type = concepts::awaitable_traits::awaiter_return_type> [[nodiscard]] auto when_all_awaitable(std::vector& awaitables) -> detail::when_all_ready_awaitable>> { diff --git a/src/io_scheduler.cpp b/src/io_scheduler.cpp index 39c1a7e..53b3db0 100644 --- a/src/io_scheduler.cpp +++ b/src/io_scheduler.cpp @@ -4,11 +4,8 @@ namespace coro { - -detail::resume_token_base::resume_token_base(io_scheduler* s) noexcept - : m_scheduler(s), m_state(nullptr) +detail::resume_token_base::resume_token_base(io_scheduler* s) noexcept : m_scheduler(s), m_state(nullptr) { - } detail::resume_token_base::resume_token_base(resume_token_base&& other) @@ -148,10 +145,10 @@ auto io_scheduler::task_manager::make_cleanup_task(task user_task, task_po io_scheduler::io_scheduler(const options opts) : m_epoll_fd(epoll_create1(EPOLL_CLOEXEC)), - m_accept_fd(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK)), - m_timer_fd(timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK | TFD_CLOEXEC)), - m_thread_strategy(opts.thread_strategy), - m_task_manager(opts.reserve_size, opts.growth_factor) + m_accept_fd(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK)), + m_timer_fd(timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK | TFD_CLOEXEC)), + m_thread_strategy(opts.thread_strategy), + m_task_manager(opts.reserve_size, opts.growth_factor) { epoll_event e{}; e.events = EPOLLIN; @@ -318,10 +315,7 @@ auto io_scheduler::poll(fd_t fd, poll_op op, std::chrono::milliseconds timeout) co_return status; } -auto io_scheduler::poll( - const net::socket& sock, - poll_op op, - std::chrono::milliseconds timeout) +auto io_scheduler::poll(const net::socket& sock, poll_op op, std::chrono::milliseconds timeout) -> coro::task { return poll(sock.native_handle(), op, timeout); @@ -340,16 +334,13 @@ auto io_scheduler::read(fd_t fd, std::span buffer, std::chrono::millisecon } } -auto io_scheduler::read( - const net::socket& sock, - std::span buffer, - std::chrono::milliseconds timeout) -> coro::task> +auto io_scheduler::read(const net::socket& sock, std::span buffer, std::chrono::milliseconds timeout) + -> coro::task> { return read(sock.native_handle(), buffer, timeout); } -auto io_scheduler::write( - fd_t fd, const std::span buffer, std::chrono::milliseconds timeout) +auto io_scheduler::write(fd_t fd, const std::span buffer, std::chrono::milliseconds timeout) -> coro::task> { auto status = co_await poll(fd, poll_op::write, timeout); @@ -362,10 +353,8 @@ auto io_scheduler::write( } } -auto io_scheduler::write( - const net::socket& sock, - const std::span buffer, - std::chrono::milliseconds timeout) -> coro::task> +auto io_scheduler::write(const net::socket& sock, const std::span buffer, std::chrono::milliseconds timeout) + -> coro::task> { return write(sock.native_handle(), buffer, timeout); } @@ -430,7 +419,8 @@ auto io_scheduler::shutdown(shutdown_t wait_for_tasks) -> void } } -auto io_scheduler::make_scheduler_after_task(coro::task task, std::chrono::milliseconds wait_time) -> coro::task +auto io_scheduler::make_scheduler_after_task(coro::task task, std::chrono::milliseconds wait_time) + -> coro::task { // Wait for the period requested, and then resume their task. co_await yield_for(wait_time); diff --git a/src/mutex.cpp b/src/mutex.cpp new file mode 100644 index 0000000..9247f21 --- /dev/null +++ b/src/mutex.cpp @@ -0,0 +1,65 @@ +#include "coro/mutex.hpp" + +namespace coro +{ +auto mutex::lock() -> awaiter +{ + return awaiter(*this); +} + +auto mutex::try_lock() -> bool +{ + bool expected = false; + return m_state.compare_exchange_strong(expected, true, std::memory_order::release, std::memory_order::relaxed); +} + +auto mutex::unlock() -> void +{ + m_state.exchange(false, std::memory_order::release); + awaiter* next{nullptr}; + { + std::scoped_lock lk{m_waiter_mutex}; + if (!m_waiter_list.empty()) + { + next = m_waiter_list.front(); + m_waiter_list.pop_front(); + } + } + + if (next != nullptr) + { + next->m_awaiting_coroutine.resume(); + } +} + +auto mutex::awaiter::await_ready() const noexcept -> bool +{ + return m_mutex.try_lock(); +} + +auto mutex::awaiter::await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool +{ + m_awaiting_coroutine = awaiting_coroutine; + + { + // Its possible between await_ready() and await_suspend() the lock was released, + // if thats the case acquire it immediately. + std::scoped_lock lk{m_mutex.m_waiter_mutex}; + if (m_mutex.m_waiter_list.empty() && m_mutex.try_lock()) + { + return false; + } + + // Ok its still held, add ourself to the wiater list. + m_mutex.m_waiter_list.emplace_back(this); + } + + return true; +} + +auto mutex::awaiter::await_resume() noexcept -> scoped_lock +{ + return scoped_lock(m_mutex); +} + +} // namespace coro diff --git a/src/net/dns_resolver.cpp b/src/net/dns_resolver.cpp index 34593bf..ab8c61d 100644 --- a/src/net/dns_resolver.cpp +++ b/src/net/dns_resolver.cpp @@ -1,26 +1,20 @@ #include "coro/net/dns_resolver.hpp" +#include #include #include -#include namespace coro::net { - -uint64_t dns_resolver::m_ares_count{0}; +uint64_t dns_resolver::m_ares_count{0}; std::mutex dns_resolver::m_ares_mutex{}; -auto ares_dns_callback( - void* arg, - int status, - int /*timeouts*/, - struct hostent* host -) -> void +auto ares_dns_callback(void* arg, int status, int /*timeouts*/, struct hostent* host) -> void { auto& result = *static_cast(arg); --result.m_pending_dns_requests; - if(host == nullptr || status != ARES_SUCCESS) + if (host == nullptr || status != ARES_SUCCESS) { result.m_status = dns_status::error; } @@ -28,19 +22,18 @@ auto ares_dns_callback( { result.m_status = dns_status::complete; - for(size_t i = 0; host->h_addr_list[i] != nullptr; ++i) + for (size_t i = 0; host->h_addr_list[i] != nullptr; ++i) { - size_t len = (host->h_addrtype == AF_INET) ? net::ip_address::ipv4_len : net::ip_address::ipv6_len; + size_t len = (host->h_addrtype == AF_INET) ? net::ip_address::ipv4_len : net::ip_address::ipv6_len; net::ip_address ip_addr{ std::span{reinterpret_cast(host->h_addr_list[i]), len}, - static_cast(host->h_addrtype) - }; + static_cast(host->h_addrtype)}; result.m_ip_addresses.emplace_back(std::move(ip_addr)); } } - if(result.m_pending_dns_requests == 0) + if (result.m_pending_dns_requests == 0) { result.m_token.resume(); } @@ -50,7 +43,6 @@ dns_result::dns_result(coro::resume_token& token, uint64_t pending_dns_req : m_token(token), m_pending_dns_requests(pending_dns_requests) { - } dns_resolver::dns_resolver(io_scheduler& scheduler, std::chrono::milliseconds timeout) @@ -59,10 +51,10 @@ dns_resolver::dns_resolver(io_scheduler& scheduler, std::chrono::milliseconds ti { { std::lock_guard g{m_ares_mutex}; - if(m_ares_count == 0) + if (m_ares_count == 0) { auto ares_status = ares_library_init(ARES_LIB_INIT_ALL); - if(ares_status != ARES_SUCCESS) + if (ares_status != ARES_SUCCESS) { throw std::runtime_error{ares_strerror(ares_status)}; } @@ -71,7 +63,7 @@ dns_resolver::dns_resolver(io_scheduler& scheduler, std::chrono::milliseconds ti } auto channel_init_status = ares_init(&m_ares_channel); - if(channel_init_status != ARES_SUCCESS) + if (channel_init_status != ARES_SUCCESS) { throw std::runtime_error{ares_strerror(channel_init_status)}; } @@ -79,7 +71,7 @@ dns_resolver::dns_resolver(io_scheduler& scheduler, std::chrono::milliseconds ti dns_resolver::~dns_resolver() { - if(m_ares_channel != nullptr) + if (m_ares_channel != nullptr) { ares_destroy(m_ares_channel); m_ares_channel = nullptr; @@ -88,7 +80,7 @@ dns_resolver::~dns_resolver() { std::lock_guard g{m_ares_mutex}; --m_ares_count; - if(m_ares_count == 0) + if (m_ares_count == 0) { ares_library_cleanup(); } @@ -97,7 +89,7 @@ dns_resolver::~dns_resolver() auto dns_resolver::host_by_name(const net::hostname& hn) -> coro::task> { - auto token = m_scheduler.make_resume_token(); + auto token = m_scheduler.make_resume_token(); auto result_ptr = std::make_unique(token, 2); ares_gethostbyname(m_ares_channel, hn.data().data(), AF_INET, ares_dns_callback, result_ptr.get()); @@ -114,26 +106,26 @@ auto dns_resolver::host_by_name(const net::hostname& hn) -> coro::task void { std::array ares_sockets{}; - std::array poll_ops{}; + std::array poll_ops{}; int bitmask = ares_getsock(m_ares_channel, ares_sockets.data(), ARES_GETSOCK_MAXNUM); size_t new_sockets{0}; - for(size_t i = 0; i < ARES_GETSOCK_MAXNUM; ++i) + for (size_t i = 0; i < ARES_GETSOCK_MAXNUM; ++i) { uint64_t ops{0}; - if(ARES_GETSOCK_READABLE(bitmask, i)) + if (ARES_GETSOCK_READABLE(bitmask, i)) { ops |= static_cast(poll_op::read); } - if(ARES_GETSOCK_WRITABLE(bitmask, i)) + if (ARES_GETSOCK_WRITABLE(bitmask, i)) { ops |= static_cast(poll_op::write); } - if(ops != 0) + if (ops != 0) { poll_ops[i] = static_cast(ops); ++new_sockets; @@ -146,12 +138,12 @@ auto dns_resolver::ares_poll() -> void } } - for(size_t i = 0; i < new_sockets; ++i) + for (size_t i = 0; i < new_sockets; ++i) { io_scheduler::fd_t fd = static_cast(ares_sockets[i]); // If this socket is not currently actively polling, start polling! - if(m_active_sockets.emplace(fd).second) + if (m_active_sockets.emplace(fd).second) { m_scheduler.schedule(make_poll_task(fd, poll_ops[i])); } @@ -161,15 +153,15 @@ auto dns_resolver::ares_poll() -> void auto dns_resolver::make_poll_task(io_scheduler::fd_t fd, poll_op ops) -> coro::task { auto result = co_await m_scheduler.poll(fd, ops, m_timeout); - switch(result) + switch (result) { case poll_status::event: { - auto read_sock = poll_op_readable(ops) ? fd : ARES_SOCKET_BAD; + auto read_sock = poll_op_readable(ops) ? fd : ARES_SOCKET_BAD; auto write_sock = poll_op_writeable(ops) ? fd : ARES_SOCKET_BAD; ares_process_fd(m_ares_channel, read_sock, write_sock); } - break; + break; case poll_status::timeout: ares_process_fd(m_ares_channel, ARES_SOCKET_BAD, ARES_SOCKET_BAD); break; diff --git a/src/net/ip_address.cpp b/src/net/ip_address.cpp index 74ea25c..78cbee7 100644 --- a/src/net/ip_address.cpp +++ b/src/net/ip_address.cpp @@ -2,16 +2,17 @@ namespace coro::net { - static std::string domain_ipv4{"ipv4"}; static std::string domain_ipv6{"ipv6"}; auto to_string(domain_t domain) -> const std::string& { - switch(domain) + switch (domain) { - case domain_t::ipv4: return domain_ipv4; - case domain_t::ipv6: return domain_ipv6; + case domain_t::ipv4: + return domain_ipv4; + case domain_t::ipv6: + return domain_ipv6; } throw std::runtime_error{"coro::net::to_string(domain_t) unknown domain"}; } diff --git a/src/net/recv_status.cpp b/src/net/recv_status.cpp index 0fa7681..ff4614b 100644 --- a/src/net/recv_status.cpp +++ b/src/net/recv_status.cpp @@ -2,7 +2,6 @@ namespace coro::net { - static const std::string recv_status_ok{"ok"}; static const std::string recv_status_closed{"closed"}; static const std::string recv_status_udp_not_bound{"udp_not_bound"}; @@ -20,21 +19,33 @@ static const std::string recv_status_unknown{"unknown"}; auto to_string(recv_status status) -> const std::string& { - switch(status) + switch (status) { - case recv_status::ok: return recv_status_ok; - case recv_status::closed: return recv_status_closed; - case recv_status::udp_not_bound: return recv_status_udp_not_bound; - //case recv_status::try_again: return recv_status_try_again; - case recv_status::would_block: return recv_status_would_block; - case recv_status::bad_file_descriptor: return recv_status_bad_file_descriptor; - case recv_status::connection_refused: return recv_status_connection_refused; - case recv_status::memory_fault: return recv_status_memory_fault; - case recv_status::interrupted: return recv_status_interrupted; - case recv_status::invalid_argument: return recv_status_invalid_argument; - case recv_status::no_memory: return recv_status_no_memory; - case recv_status::not_connected: return recv_status_not_connected; - case recv_status::not_a_socket: return recv_status_not_a_socket; + case recv_status::ok: + return recv_status_ok; + case recv_status::closed: + return recv_status_closed; + case recv_status::udp_not_bound: + return recv_status_udp_not_bound; + // case recv_status::try_again: return recv_status_try_again; + case recv_status::would_block: + return recv_status_would_block; + case recv_status::bad_file_descriptor: + return recv_status_bad_file_descriptor; + case recv_status::connection_refused: + return recv_status_connection_refused; + case recv_status::memory_fault: + return recv_status_memory_fault; + case recv_status::interrupted: + return recv_status_interrupted; + case recv_status::invalid_argument: + return recv_status_invalid_argument; + case recv_status::no_memory: + return recv_status_no_memory; + case recv_status::not_connected: + return recv_status_not_connected; + case recv_status::not_a_socket: + return recv_status_not_a_socket; } return recv_status_unknown; diff --git a/src/net/send_status.cpp b/src/net/send_status.cpp index 336e6b1..bc7b702 100644 --- a/src/net/send_status.cpp +++ b/src/net/send_status.cpp @@ -2,5 +2,4 @@ namespace coro::net { - } // namespace coro::net diff --git a/src/net/socket.cpp b/src/net/socket.cpp index 6071fda..3969818 100644 --- a/src/net/socket.cpp +++ b/src/net/socket.cpp @@ -2,7 +2,6 @@ namespace coro::net { - auto socket::type_to_os(type_t type) -> int { switch (type) @@ -96,11 +95,8 @@ auto make_socket(const socket::options& opts) -> socket return s; } -auto make_accept_socket( - const socket::options& opts, - const net::ip_address& address, - uint16_t port, - int32_t backlog) -> socket +auto make_accept_socket(const socket::options& opts, const net::ip_address& address, uint16_t port, int32_t backlog) + -> socket { socket s = make_socket(opts); @@ -120,7 +116,7 @@ auto make_accept_socket( throw std::runtime_error{"Failed to bind."}; } - if(opts.type == socket::type_t::tcp) + if (opts.type == socket::type_t::tcp) { if (listen(s.native_handle(), backlog) < 0) { diff --git a/src/net/tcp_client.cpp b/src/net/tcp_client.cpp index 6b9fe8e..e3561bd 100644 --- a/src/net/tcp_client.cpp +++ b/src/net/tcp_client.cpp @@ -10,10 +10,8 @@ using namespace std::chrono_literals; tcp_client::tcp_client(io_scheduler& scheduler, options opts) : m_io_scheduler(scheduler), m_options(std::move(opts)), - m_socket(net::make_socket(net::socket::options{ - m_options.address.domain(), - net::socket::type_t::tcp, - net::socket::blocking_t::no})) + m_socket(net::make_socket( + net::socket::options{m_options.address.domain(), net::socket::type_t::tcp, net::socket::blocking_t::no})) { } @@ -23,12 +21,11 @@ tcp_client::tcp_client(io_scheduler& scheduler, net::socket socket, options opts m_socket(std::move(socket)), m_connect_status(connect_status::connected) { - } auto tcp_client::connect(std::chrono::milliseconds timeout) -> coro::task { - if(m_connect_status.has_value() && m_connect_status.value() == connect_status::connected) + if (m_connect_status.has_value() && m_connect_status.value() == connect_status::connected) { co_return m_connect_status.value(); } diff --git a/src/net/tcp_server.cpp b/src/net/tcp_server.cpp index 834769e..ef2e4d7 100644 --- a/src/net/tcp_server.cpp +++ b/src/net/tcp_server.cpp @@ -2,17 +2,15 @@ namespace coro::net { - tcp_server::tcp_server(io_scheduler& scheduler, options opts) - : m_io_scheduler(scheduler), - m_options(std::move(opts)), - m_accept_socket(net::make_accept_socket( - net::socket::options{net::domain_t::ipv4, net::socket::type_t::tcp, net::socket::blocking_t::no}, - m_options.address, - m_options.port, - m_options.backlog)) + : m_io_scheduler(scheduler), + m_options(std::move(opts)), + m_accept_socket(net::make_accept_socket( + net::socket::options{net::domain_t::ipv4, net::socket::type_t::tcp, net::socket::blocking_t::no}, + m_options.address, + m_options.port, + m_options.backlog)) { - } auto tcp_server::poll(std::chrono::milliseconds timeout) -> coro::task @@ -24,7 +22,7 @@ auto tcp_server::accept() -> coro::net::tcp_client { sockaddr_in client{}; constexpr const int len = sizeof(struct sockaddr_in); - net::socket s{::accept(m_accept_socket.native_handle(), (struct sockaddr*)&client, (socklen_t*)&len)}; + net::socket s{::accept(m_accept_socket.native_handle(), (struct sockaddr*)&client, (socklen_t*)&len)}; std::span ip_addr_view{ reinterpret_cast(&client.sin_addr.s_addr), @@ -32,13 +30,11 @@ auto tcp_server::accept() -> coro::net::tcp_client }; return tcp_client{ - m_io_scheduler, - std::move(s), - tcp_client::options{ - .address = net::ip_address{ip_addr_view, static_cast(client.sin_family)}, - .port = ntohs(client.sin_port) - } - }; + m_io_scheduler, + std::move(s), + tcp_client::options{ + .address = net::ip_address{ip_addr_view, static_cast(client.sin_family)}, + .port = ntohs(client.sin_port)}}; }; } // namespace coro::net diff --git a/src/net/udp_peer.cpp b/src/net/udp_peer.cpp index 5cad87e..02353d4 100644 --- a/src/net/udp_peer.cpp +++ b/src/net/udp_peer.cpp @@ -2,31 +2,20 @@ namespace coro::net { - -udp_peer::udp_peer( - io_scheduler& scheduler, - net::domain_t domain) - : m_io_scheduler(scheduler), - m_socket(net::make_socket( - net::socket::options{ - domain, - net::socket::type_t::udp, - net::socket::blocking_t::no})) +udp_peer::udp_peer(io_scheduler& scheduler, net::domain_t domain) + : m_io_scheduler(scheduler), + m_socket(net::make_socket(net::socket::options{domain, net::socket::type_t::udp, net::socket::blocking_t::no})) { - } -udp_peer::udp_peer( - io_scheduler& scheduler, - const info& bind_info) - : m_io_scheduler(scheduler), - m_socket(net::make_accept_socket( - net::socket::options{bind_info.address.domain(), net::socket::type_t::udp, net::socket::blocking_t::no}, - bind_info.address, - bind_info.port)), - m_bound(true) +udp_peer::udp_peer(io_scheduler& scheduler, const info& bind_info) + : m_io_scheduler(scheduler), + m_socket(net::make_accept_socket( + net::socket::options{bind_info.address.domain(), net::socket::type_t::udp, net::socket::blocking_t::no}, + bind_info.address, + bind_info.port)), + m_bound(true) { - } } // namespace coro::net diff --git a/src/thread_pool.cpp b/src/thread_pool.cpp index 58a1f9d..86b43b2 100644 --- a/src/thread_pool.cpp +++ b/src/thread_pool.cpp @@ -99,7 +99,7 @@ auto thread_pool::executor(std::stop_token stop_token, std::size_t idx) -> void if (op != nullptr && op->m_awaiting_coroutine != nullptr) { op->m_awaiting_coroutine.resume(); - m_size.fetch_sub(1, std::memory_order_relaxed); + m_size.fetch_sub(1, std::memory_order::relaxed); } else { diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 9f95a14..f30b548 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -12,6 +12,7 @@ set(LIBCORO_TEST_SOURCE_FILES test_generator.cpp test_io_scheduler.cpp test_latch.cpp + test_mutex.cpp test_sync_wait.cpp test_task.cpp test_thread_pool.cpp @@ -21,10 +22,10 @@ set(LIBCORO_TEST_SOURCE_FILES add_executable(${PROJECT_NAME} main.cpp ${LIBCORO_TEST_SOURCE_FILES}) target_compile_features(${PROJECT_NAME} PUBLIC cxx_std_20) target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) -target_link_libraries(${PROJECT_NAME} PRIVATE coro) +target_link_libraries(${PROJECT_NAME} PRIVATE libcoro) target_compile_options(${PROJECT_NAME} PUBLIC -fcoroutines) -if(CORO_CODE_COVERAGE) +if(LIBCORO_CODE_COVERAGE) target_compile_options(${PROJECT_NAME} PRIVATE --coverage) target_link_libraries(${PROJECT_NAME} PRIVATE gcov) endif() @@ -35,4 +36,4 @@ elseif(${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") message(FATAL_ERROR "Clang is currently not supported.") endif() -add_test(NAME coro_test COMMAND ${PROJECT_NAME}) \ No newline at end of file +add_test(NAME libcoro_tests COMMAND ${PROJECT_NAME}) \ No newline at end of file diff --git a/test/bench.cpp b/test/bench.cpp index 9c18f29..70542f1 100644 --- a/test/bench.cpp +++ b/test/bench.cpp @@ -344,9 +344,9 @@ TEST_CASE("benchmark tcp_server echo server") * will reset/trample on each other when each side of the client + server go to poll(). */ - const constexpr std::size_t connections = 64; + const constexpr std::size_t connections = 64; const constexpr std::size_t messages_per_connection = 10'000; - const constexpr std::size_t ops = connections * messages_per_connection; + const constexpr std::size_t ops = connections * messages_per_connection; const std::string msg = "im a data point in a stream of bytes"; @@ -359,13 +359,13 @@ TEST_CASE("benchmark tcp_server echo server") std::string in(64, '\0'); // Echo the messages until the socket is closed. a 'done' message arrives. - while(true) + while (true) { auto pstatus = co_await client.poll(coro::poll_op::read); REQUIRE(pstatus == coro::poll_status::event); auto [rstatus, rspan] = client.recv(in); - if(rstatus == coro::net::recv_status::closed) + if (rstatus == coro::net::recv_status::closed) { REQUIRE(rspan.empty()); break; @@ -389,7 +389,7 @@ TEST_CASE("benchmark tcp_server echo server") listening = true; uint64_t accepted{0}; - while(accepted < connections) + while (accepted < connections) { auto pstatus = co_await server.poll(); REQUIRE(pstatus == coro::poll_status::event); @@ -411,7 +411,7 @@ TEST_CASE("benchmark tcp_server echo server") auto cstatus = co_await client.connect(); REQUIRE(cstatus == coro::net::connect_status::connected); - for(size_t i = 1; i <= messages_per_connection; ++i) + for (size_t i = 1; i <= messages_per_connection; ++i) { auto [sstatus, remaining] = client.send(msg); REQUIRE(sstatus == coro::net::send_status::ok); @@ -438,13 +438,13 @@ TEST_CASE("benchmark tcp_server echo server") // The server can take a small bit of time to start up, if we don't wait for it to notify then // the first few connections can easily fail to connect causing this test to fail. - while(!listening) + while (!listening) { std::this_thread::sleep_for(std::chrono::milliseconds{1}); } // Spawn N client connections. - for(size_t i = 0; i < connections; ++i) + for (size_t i = 0; i < connections; ++i) { REQUIRE(client_scheduler.schedule(make_client_task())); } diff --git a/test/net/test_dns_resolver.cpp b/test/net/test_dns_resolver.cpp index 71683e1..919e224 100644 --- a/test/net/test_dns_resolver.cpp +++ b/test/net/test_dns_resolver.cpp @@ -7,20 +7,18 @@ TEST_CASE("dns_resolver basic") { coro::io_scheduler scheduler{ - coro::io_scheduler::options{.thread_strategy = coro::io_scheduler::thread_strategy_t::spawn} - }; + coro::io_scheduler::options{.thread_strategy = coro::io_scheduler::thread_strategy_t::spawn}}; coro::net::dns_resolver dns_resolver{scheduler, std::chrono::milliseconds{5000}}; std::atomic done{false}; - auto make_host_by_name_task = [&](coro::net::hostname hn) -> coro::task - { + auto make_host_by_name_task = [&](coro::net::hostname hn) -> coro::task { auto result_ptr = co_await std::move(dns_resolver.host_by_name(hn)); - if(result_ptr->status() == coro::net::dns_status::complete) + if (result_ptr->status() == coro::net::dns_status::complete) { - for(const auto& ip_addr : result_ptr->ip_addresses()) + for (const auto& ip_addr : result_ptr->ip_addresses()) { std::cerr << coro::net::to_string(ip_addr.domain()) << " " << ip_addr.to_string() << "\n"; } @@ -33,7 +31,7 @@ TEST_CASE("dns_resolver basic") scheduler.schedule(make_host_by_name_task(coro::net::hostname{"www.example.com"})); - while(!done) + while (!done) { std::this_thread::sleep_for(std::chrono::milliseconds{10}); } diff --git a/test/net/test_ip_address.cpp b/test/net/test_ip_address.cpp index a172df3..7312dcc 100644 --- a/test/net/test_ip_address.cpp +++ b/test/net/test_ip_address.cpp @@ -27,10 +27,12 @@ TEST_CASE("net::ip_address from_string() ipv4") TEST_CASE("net::ip_address from_string() ipv6") { { - auto ip_addr = coro::net::ip_address::from_string("0123:4567:89ab:cdef:0123:4567:89ab:cdef", coro::net::domain_t::ipv6); + auto ip_addr = + coro::net::ip_address::from_string("0123:4567:89ab:cdef:0123:4567:89ab:cdef", coro::net::domain_t::ipv6); REQUIRE(ip_addr.to_string() == "123:4567:89ab:cdef:123:4567:89ab:cdef"); REQUIRE(ip_addr.domain() == coro::net::domain_t::ipv6); - std::array expected{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}; + std::array expected{ + 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}; REQUIRE(std::equal(expected.begin(), expected.end(), ip_addr.data().begin())); } @@ -46,7 +48,8 @@ TEST_CASE("net::ip_address from_string() ipv6") auto ip_addr = coro::net::ip_address::from_string("::1", coro::net::domain_t::ipv6); REQUIRE(ip_addr.to_string() == "::1"); REQUIRE(ip_addr.domain() == coro::net::domain_t::ipv6); - std::array expected{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}; + std::array expected{ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}; REQUIRE(std::equal(expected.begin(), expected.end(), ip_addr.data().begin())); } @@ -54,7 +57,8 @@ TEST_CASE("net::ip_address from_string() ipv6") auto ip_addr = coro::net::ip_address::from_string("1::1", coro::net::domain_t::ipv6); REQUIRE(ip_addr.to_string() == "1::1"); REQUIRE(ip_addr.domain() == coro::net::domain_t::ipv6); - std::array expected{0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}; + std::array expected{ + 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}; REQUIRE(std::equal(expected.begin(), expected.end(), ip_addr.data().begin())); } @@ -62,7 +66,8 @@ TEST_CASE("net::ip_address from_string() ipv6") auto ip_addr = coro::net::ip_address::from_string("1::", coro::net::domain_t::ipv6); REQUIRE(ip_addr.to_string() == "1::"); REQUIRE(ip_addr.domain() == coro::net::domain_t::ipv6); - std::array expected{0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; + std::array expected{ + 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; REQUIRE(std::equal(expected.begin(), expected.end(), ip_addr.data().begin())); } } diff --git a/test/net/test_tcp_server.cpp b/test/net/test_tcp_server.cpp index fa95773..83d6059 100644 --- a/test/net/test_tcp_server.cpp +++ b/test/net/test_tcp_server.cpp @@ -64,7 +64,7 @@ TEST_CASE("tcp_server ping server") scheduler.schedule(make_server_task()); scheduler.schedule(make_client_task()); - while(!scheduler.empty()) + while (!scheduler.empty()) { std::this_thread::sleep_for(std::chrono::milliseconds{1}); } diff --git a/test/net/test_udp_peers.cpp b/test/net/test_udp_peers.cpp index 622c6b1..34ad3a2 100644 --- a/test/net/test_udp_peers.cpp +++ b/test/net/test_udp_peers.cpp @@ -9,7 +9,7 @@ TEST_CASE("udp one way") coro::io_scheduler scheduler{}; auto make_send_task = [&]() -> coro::task { - coro::net::udp_peer peer{scheduler}; + coro::net::udp_peer peer{scheduler}; coro::net::udp_peer::info peer_info{}; auto [sstatus, remaining] = peer.sendto(peer_info, msg); @@ -20,9 +20,7 @@ TEST_CASE("udp one way") }; auto make_recv_task = [&]() -> coro::task { - coro::net::udp_peer::info self_info{ - .address = coro::net::ip_address::from_string("0.0.0.0") - }; + coro::net::udp_peer::info self_info{.address = coro::net::ip_address::from_string("0.0.0.0")}; coro::net::udp_peer self{scheduler, self_info}; @@ -53,18 +51,18 @@ TEST_CASE("udp echo peers") coro::io_scheduler scheduler{}; auto make_peer_task = [&scheduler]( - uint16_t my_port, - uint16_t peer_port, - bool send_first, - const std::string my_msg, - const std::string peer_msg) -> coro::task { - + uint16_t my_port, + uint16_t peer_port, + bool send_first, + const std::string my_msg, + const std::string peer_msg) -> coro::task { coro::net::udp_peer::info my_info{.address = coro::net::ip_address::from_string("0.0.0.0"), .port = my_port}; - coro::net::udp_peer::info peer_info{.address = coro::net::ip_address::from_string("127.0.0.1"), .port = peer_port}; + coro::net::udp_peer::info peer_info{ + .address = coro::net::ip_address::from_string("127.0.0.1"), .port = peer_port}; coro::net::udp_peer me{scheduler, my_info}; - if(send_first) + if (send_first) { // Send my message to my peer first. auto [sstatus, remaining] = me.sendto(peer_info, my_msg); @@ -86,7 +84,7 @@ TEST_CASE("udp echo peers") REQUIRE(buffer == peer_msg); } - if(send_first) + if (send_first) { // I sent first so now I need to await my peer's message. auto pstatus = co_await me.poll(coro::poll_op::read); @@ -111,5 +109,5 @@ TEST_CASE("udp echo peers") }; scheduler.schedule(make_peer_task(8081, 8080, false, peer2_msg, peer1_msg)); - scheduler.schedule(make_peer_task(8080, 8081, true, peer1_msg, peer2_msg)); + scheduler.schedule(make_peer_task(8080, 8081, true, peer1_msg, peer2_msg)); } diff --git a/test/test_io_scheduler.cpp b/test/test_io_scheduler.cpp index c3b6f7c..5de9f1b 100644 --- a/test/test_io_scheduler.cpp +++ b/test/test_io_scheduler.cpp @@ -16,8 +16,8 @@ TEST_CASE("io_scheduler sizeof()") std::cerr << "sizeof(coro:task)=[" << sizeof(coro::task) << "]\n"; std::cerr << "sizeof(std::coroutine_handle<>)=[" << sizeof(std::coroutine_handle<>) << "]\n"; - std::cerr << "sizeof(std::variant, std::coroutine_handle<>>)=[" << sizeof(std::variant, std::coroutine_handle<>>) - << "]\n"; + std::cerr << "sizeof(std::variant, std::coroutine_handle<>>)=[" + << sizeof(std::variant, std::coroutine_handle<>>) << "]\n"; REQUIRE(true); } diff --git a/test/test_mutex.cpp b/test/test_mutex.cpp new file mode 100644 index 0000000..e24a654 --- /dev/null +++ b/test/test_mutex.cpp @@ -0,0 +1,82 @@ +#include "catch.hpp" + +#include + +#include +#include + +TEST_CASE("mutex single waiter not locked") +{ + std::vector output; + + coro::mutex m; + + auto make_emplace_task = [&](coro::mutex& m) -> coro::task { + std::cerr << "Acquiring lock\n"; + auto scoped_lock = co_await m.lock(); + std::cerr << "lock acquired, emplacing back 1\n"; + output.emplace_back(1); + std::cerr << "coroutine done\n"; + co_return; + }; + + coro::sync_wait(make_emplace_task(m)); + + REQUIRE(m.try_lock()); + m.unlock(); + + REQUIRE(output.size() == 1); + REQUIRE(output[0] == 1); +} + +TEST_CASE("mutex many waiters until event") +{ + std::atomic value{0}; + std::vector> tasks; + + coro::thread_pool tp{coro::thread_pool::options{.thread_count = 1}}; + + coro::mutex m; // acquires and holds the lock until the event is triggered + coro::event e; // triggers the blocking thread to release the lock + + auto make_task = [&](uint64_t id) -> coro::task { + co_await tp.schedule().value(); + std::cerr << "id = " << id << " waiting to acquire the lock\n"; + auto scoped_lock = co_await m.lock(); + std::cerr << "id = " << id << " lock acquired\n"; + value.fetch_add(1, std::memory_order::relaxed); + std::cerr << "id = " << id << " coroutine done\n"; + co_return; + }; + + auto make_block_task = [&]() -> coro::task { + co_await tp.schedule().value(); + std::cerr << "block task acquiring lock\n"; + auto scoped_lock = co_await m.lock(); + std::cerr << "block task acquired lock, waiting on event\n"; + co_await e; + co_return; + }; + + auto make_set_task = [&]() -> coro::task { + co_await tp.schedule().value(); + std::cerr << "set task setting event\n"; + e.set(); + co_return; + }; + + // Grab mutex so all threads block. + tasks.emplace_back(make_block_task()); + + // Create N tasks that attempt to lock the mutex. + for (uint64_t i = 1; i <= 4; ++i) + { + tasks.emplace_back(make_task(i)); + } + + tasks.emplace_back(make_set_task()); + + coro::sync_wait(coro::when_all_awaitable(tasks)); + + REQUIRE(value == 4); +} \ No newline at end of file diff --git a/test/test_thread_pool.cpp b/test/test_thread_pool.cpp index 8e8afde..e23b72f 100644 --- a/test/test_thread_pool.cpp +++ b/test/test_thread_pool.cpp @@ -157,3 +157,36 @@ TEST_CASE("thread_pool schedule functor return_type = void") REQUIRE_THROWS(coro::sync_wait(tp.schedule(f, std::ref(counter)))); } + +TEST_CASE("thread_pool event jump threads") +{ + // This test verifies that the thread that sets the event ends up executing every waiter on the event + + coro::thread_pool tp1{coro::thread_pool::options{.thread_count = 1}}; + coro::thread_pool tp2{coro::thread_pool::options{.thread_count = 1}}; + + coro::event e{}; + + auto make_tp1_task = [&]() -> coro::task { + co_await tp1.schedule().value(); + auto before_thread_id = std::this_thread::get_id(); + std::cerr << "before event thread_id = " << before_thread_id << "\n"; + co_await e; + auto after_thread_id = std::this_thread::get_id(); + std::cerr << "after event thread_id = " << after_thread_id << "\n"; + + REQUIRE(before_thread_id != after_thread_id); + + co_return; + }; + + auto make_tp2_task = [&]() -> coro::task { + co_await tp2.schedule().value(); + std::this_thread::sleep_for(std::chrono::milliseconds{10}); + std::cerr << "setting event\n"; + e.set(); + co_return; + }; + + coro::sync_wait(coro::when_all_awaitable(make_tp1_task(), make_tp2_task())); +} \ No newline at end of file