diff --git a/AsyncManager.cpp b/AsyncManager.cpp new file mode 100644 index 0000000..26de72f --- /dev/null +++ b/AsyncManager.cpp @@ -0,0 +1,114 @@ +#include "AsyncManager.hpp" +#include "Runtime.hpp" + +#include +#include +#ifndef PLATFORM_WINDOWS +# include +#else +# include +#endif + + + +void AsyncManager::cleanFutureMap(SockFutureMap& map) { + std::vector erasureQueue; + for (auto it = map.begin(); it != map.end(); it++) { + if (!it->second) [[unlikely]] { + erasureQueue.push_back(it); + } + } + for (auto& it : erasureQueue) { + map.erase(it); + } +} + +void AsyncManager::run() { + while (!stopping && runtime.cooperate()) { + // We should stop once there is nothihng left to do + if (sockReads.empty() && sockWrites.empty()) [[unlikely]] { + break; + } + + // We need to keep track of the highest fd for socket() + int maxFd = 0; + + // Create except FD set + fd_set exceptFds; + FD_ZERO(&exceptFds); + + // Create write FD set + fd_set writeFds; + FD_ZERO(&writeFds); + for (const auto& [fd, cb] : sockWrites) { + FD_SET(fd, &writeFds); + FD_SET(fd, &exceptFds); + if (fd > maxFd) { + maxFd = fd; + } + } + + // Create read FD set + fd_set readFds; + FD_ZERO(&readFds); + for (const auto& [fd, cb] : sockReads) { + FD_SET(fd, &readFds); + FD_SET(fd, &exceptFds); + if (fd > maxFd) { + maxFd = fd; + } + } + + // Specify timeout + timeval tv{ + .tv_sec = 0, + .tv_usec = 250000 + }; + + // select() until there is data + bool error = false; + if (select(maxFd+1, &readFds, &writeFds, &exceptFds, &tv) < 0) { + FD_ZERO(&readFds); + FD_ZERO(&writeFds); + error = true; + } + + // Execution queue + std::vector> execQueue; + + // Collect all write futures + for (auto& [fd, future] : sockWrites) { + if (FD_ISSET(fd, &writeFds)) { + // Socket is ready for writing + execQueue.push_back({future, false}); + } + if (FD_ISSET(fd, &exceptFds) || error) [[unlikely]] { + // An exception happened in the socket + execQueue.push_back({future, true}); + } + } + + // Collect all read futures + for (auto& [fd, future] : sockReads) { + if (FD_ISSET(fd, &readFds)) { + // Socket is ready for reading + execQueue.push_back({future, false}); + } + if (FD_ISSET(fd, &exceptFds) || error) [[unlikely]] { + // An exception happened in the socket + execQueue.push_back({future, true}); + } + } + + // Set futures + for (auto& [future, value] : execQueue) { + future->set(value); + future = nullptr; + } + + // Clean future maps + cleanFutureMap(sockWrites); + cleanFutureMap(sockReads); + } + stopping = false; +} diff --git a/AsyncManager.hpp b/AsyncManager.hpp new file mode 100644 index 0000000..94e7cec --- /dev/null +++ b/AsyncManager.hpp @@ -0,0 +1,59 @@ +#ifndef _ASYNCMANAGER_HPP +#define _ASYNCMANAGER_HPP +#include "Runtime.hpp" + +#include +#include +#include "basic-coro/AwaitableTask.hpp" +#include "basic-coro/SingleEvent.hpp" + +class Runtime; + + + +class AsyncManager { +public: + using SockError = bool; + using SockFuture = basiccoro::SingleEvent; + using SockFutureUnique = std::unique_ptr; + using SockFutureMap = std::unordered_multimap; + +private: + Runtime& runtime; + + SockFutureMap sockReads; + SockFutureMap sockWrites; + bool stopping = false; + + static + void cleanFutureMap(SockFutureMap&); + +public: + AsyncManager(Runtime& runtime) : runtime(runtime) {} + AsyncManager(AsyncManager&) = delete; + AsyncManager(const AsyncManager&) = delete; + AsyncManager(AsyncManager&&) = delete; + + void run(); + void stop() { + stopping = true; + } + + basiccoro::AwaitableTask waitRead(int fd) { + auto event = std::make_unique(); + auto eventPtr = event.get(); + sockReads.emplace(fd, std::move(event)); + co_return co_await *eventPtr; + } + basiccoro::AwaitableTask waitWrite(int fd) { + auto event = std::make_unique(); + auto eventPtr = event.get(); + sockWrites.emplace(fd, std::move(event)); + co_return co_await *eventPtr; + } + + auto& getRuntime() const { + return runtime; + } +}; +#endif diff --git a/CMakeLists.txt b/CMakeLists.txt index 298cbd9..ff24ae8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.5) project(llama.any LANGUAGES CXX) -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) add_executable(llama @@ -11,7 +11,9 @@ add_executable(llama Socket.hpp Receiver.hpp Receiver.cpp Sender.hpp Sender.cpp + AsyncManager.hpp AsyncManager.cpp Runtime.cpp Runtime.hpp + basic-coro/AwaitableTask.hpp basic-coro/SingleEvent.hpp basic-coro/SingleEvent.cpp ) target_compile_definitions(llama PUBLIC PLATFORM="${CMAKE_SYSTEM_NAME}") diff --git a/Client.cpp b/Client.cpp index c0a42b4..324a3a8 100644 --- a/Client.cpp +++ b/Client.cpp @@ -45,9 +45,9 @@ void Client::fetchAddr(const std::string& addr, unsigned port) { } } -Client::Client(const std::string& addr, unsigned port) { +Client::Client(const std::string& addr, unsigned port, AsyncManager& asyncManager) : aMan(asyncManager) { // Create socket - connection = std::make_unique>(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); //TODO: Care about IPv6 + connection = std::make_unique>(aMan, Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); //TODO: Care about IPv6 if (*connection < 0) [[unlikely]] { throw Exception("Failed to create TCP socket"); } @@ -72,25 +72,25 @@ Client::Client(const std::string& addr, unsigned port) { # endif } -void Client::ask(std::string_view prompt, const std::function& on_progress, const std::function& on_token) { +basiccoro::AwaitableTask Client::ask(std::string_view prompt, const std::function (unsigned progress)>& on_progress, const std::function (std::string_view token)>& on_token) { std::string fres; // Send prompt length uint8_t len = prompt.length(); - connection->writeObject(len, true); + co_await connection->writeObject(len, true); // Send prompt - connection->write(prompt); + co_await connection->write(prompt); // Receive progress for (;;) { uint8_t progress; // Receive percentage - connection->readObject(progress); + co_await connection->readObject(progress); // Run on_progress callback - on_progress(progress); + co_await on_progress(progress); // Stop at 100% if (progress == 100) break; @@ -99,15 +99,15 @@ void Client::ask(std::string_view prompt, const std::functionreadObject(len); + co_await connection->readObject(len); // End if zero if (len == 0xFF) break; // Receive response - const auto token = connection->read(len); + const auto token = co_await connection->read(len); // Run on_token callback - on_token(token); + co_await on_token(token); } } diff --git a/Client.hpp b/Client.hpp index 1532a4f..1feafe6 100644 --- a/Client.hpp +++ b/Client.hpp @@ -1,9 +1,11 @@ #ifndef CLIENT_HPP #define CLIENT_HPP #include "Runtime.hpp" +#include "AsyncManager.hpp" #include "Socket.hpp" #include "Sender.hpp" #include "Receiver.hpp" +#include "basic-coro/AwaitableTask.hpp" #include #include @@ -23,6 +25,8 @@ class Client using std::runtime_error::runtime_error; }; + AsyncManager& aMan; + int fd = -1; # ifdef HAS_ADDRINFO addrinfo @@ -37,9 +41,9 @@ class Client void fetchAddr(const std::string& addr, unsigned port); public: - Client(const std::string &addr, unsigned port); + Client(const std::string &addr, unsigned port, AsyncManager& asyncManager); - void ask(std::string_view prompt, const std::function& on_progress, const std::function& on_token); + basiccoro::AwaitableTask ask(std::string_view prompt, const std::function (unsigned progress)>& on_progress, const std::function (std::string_view token)>& on_token); }; #endif // CLIENT_HPP diff --git a/Receiver.cpp b/Receiver.cpp index 85ac957..30a08c0 100644 --- a/Receiver.cpp +++ b/Receiver.cpp @@ -4,40 +4,63 @@ #include #ifndef PLATFORM_WINDOWS # include -# include #else # include #endif -std::string Receiver::Simple::read(size_t amount) { +basiccoro::AwaitableTask Receiver::Simple::read(size_t amount) { // Create buffer std::string fres; fres.resize(amount); // Read into buffer - read(reinterpret_cast(fres.data()), fres.size()); + co_await read(reinterpret_cast(fres.data()), fres.size()); // Return final buffer - return fres; + co_return fres; } -void Receiver::Simple::read(std::byte *buffer, size_t size) { - recv(fd, reinterpret_cast(buffer), size, MSG_WAITALL); +basiccoro::AwaitableTask Receiver::Simple::read(std::byte *buffer, size_t size) { + size_t allBytesRead = 0; + + while (allBytesRead != size) { + // Wait for data + if (co_await aMan.waitRead(fd)) [[unlikely]] { + // Error + co_return true; + } + + // Receive data + ssize_t bytesRead; + if ((bytesRead = recv(fd, reinterpret_cast(buffer+allBytesRead), size-allBytesRead, 0)) < 0) [[unlikely]] { + // Error + co_return true; + } + allBytesRead += bytesRead; + } + + // No error + co_return false; } -std::string Receiver::Simple::readSome(size_t max) { +basiccoro::AwaitableTask Receiver::Simple::readSome(size_t max) { // Create buffer std::string fres; fres.resize(max); + // Wait for data + if (co_await aMan.waitRead(fd)) [[unlikely]] { + co_return ""; + } + // Receive data ssize_t bytesRead; - if ((bytesRead = recv(fd, fres.data(), max, MSG_WAITALL)) < 0) [[unlikely]] { - return ""; + if ((bytesRead = recv(fd, fres.data(), max, 0)) < 0) [[unlikely]] { + co_return ""; } // Resize and return final buffer fres.resize(bytesRead); - return fres; + co_return fres; } diff --git a/Receiver.hpp b/Receiver.hpp index 85ab5dc..6d8060f 100644 --- a/Receiver.hpp +++ b/Receiver.hpp @@ -1,24 +1,29 @@ #ifndef _RECEIVER_HPP #define _RECEIVER_HPP #include "Runtime.hpp" +#include "AsyncManager.hpp" #include #include +#include "basic-coro/AwaitableTask.hpp" namespace Receiver { class Simple { + Runtime& runtime; + AsyncManager &aMan; + protected: int fd; public: - Simple(int fd) : fd(fd) {} + Simple(AsyncManager& asyncManager, int fd) : runtime(asyncManager.getRuntime()), aMan(asyncManager), fd(fd) {} // Reads the exact amount of bytes given - std::string read(size_t amount); - void read(std::byte *buffer, size_t size); + basiccoro::AwaitableTask read(size_t amount); + basiccoro::AwaitableTask read(std::byte *buffer, size_t size); // Reads at max. the amount of bytes given - std::string readSome(size_t max); + basiccoro::AwaitableTask readSome(size_t max); // Reads an object of type T template diff --git a/Sender.cpp b/Sender.cpp index 23545bc..5327120 100644 --- a/Sender.cpp +++ b/Sender.cpp @@ -1,9 +1,9 @@ #include "Runtime.hpp" #include "Sender.hpp" +#include "basic-coro/AwaitableTask.hpp" #include #ifndef PLATFORM_WINDOWS -# include # include #else # include @@ -12,13 +12,18 @@ -void Sender::Simple::write(std::string_view str, bool moreData) { - this->write(reinterpret_cast(str.data()), str.size(), moreData); +basiccoro::AwaitableTask Sender::Simple::write(std::string_view str, bool moreData) { + co_return co_await this->write(reinterpret_cast(str.data()), str.size(), moreData); } -void Sender::Simple::write(const std::byte *data, size_t size, bool moreData) { +basiccoro::AwaitableTask Sender::Simple::write(const std::byte *data, size_t size, bool moreData) { std::string fres; + // Wait for socket to get ready for writing + if (co_await aMan.waitWrite(fd)) [[unlikely]] { + co_return true; + } + // Write - send(fd, reinterpret_cast(data), size, MSG_FLAGS_OR_ZERO(MSG_WAITALL | MSG_NOSIGNAL | (int(moreData)*MSG_MORE))); + co_return send(fd, reinterpret_cast(data), size, MSG_FLAGS_OR_ZERO(MSG_NOSIGNAL | (int(moreData)*MSG_MORE))) < 0; } diff --git a/Sender.hpp b/Sender.hpp index e829647..404606a 100644 --- a/Sender.hpp +++ b/Sender.hpp @@ -1,20 +1,25 @@ #ifndef _SENDER_HPP #define _SENDER_HPP +#include "AsyncManager.hpp" + #include #include #include +#include "basic-coro/AwaitableTask.hpp" namespace Sender { class Simple { + AsyncManager &aMan; + protected: int fd; public: - Simple(int fd) : fd(fd) {} + Simple(AsyncManager& asyncManager, int fd) : aMan(asyncManager), fd(fd) {} - void write(std::string_view, bool moreData = false); - void write(const std::byte *data, size_t, bool moreData = false); + basiccoro::AwaitableTask write(std::string_view, bool moreData = false); + basiccoro::AwaitableTask write(const std::byte *data, size_t, bool moreData = false); template auto writeObject(const T& o, bool moreData = false) { diff --git a/Socket.hpp b/Socket.hpp index c05e597..648751f 100644 --- a/Socket.hpp +++ b/Socket.hpp @@ -1,14 +1,14 @@ #ifndef _SOCKET_HPP #define _SOCKET_HPP +#include "AsyncManager.hpp" + #include #include -#if defined(PLATFORM_WINDOWS) -# include -#elif defined(PLATFORM_WII) -#include -#else +#ifndef PLATFORM_WINDOWS # include # include +#else +# include #endif @@ -58,9 +58,9 @@ public: template class SocketConnection : public SenderT, public ReceiverT, public Socket { public: - SocketConnection(Socket&& socket) + SocketConnection(AsyncManager& asyncManager, Socket&& socket) // Double-initialization seems to yield better assembly - : SenderT(socket), ReceiverT(socket), Socket(std::move(socket)) { + : SenderT(asyncManager, socket), ReceiverT(asyncManager, socket), Socket(std::move(socket)) { SenderT::fd = get(); ReceiverT::fd = get(); } diff --git a/basic-coro/AwaitableTask.hpp b/basic-coro/AwaitableTask.hpp new file mode 100644 index 0000000..04130b2 --- /dev/null +++ b/basic-coro/AwaitableTask.hpp @@ -0,0 +1,195 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace basiccoro +{ +namespace detail +{ + +template +struct PromiseBase +{ + auto get_return_object() { return std::coroutine_handle::from_promise(static_cast(*this)); } + void unhandled_exception() { std::terminate(); } +}; + +template requires std::movable || std::same_as +struct ValuePromise : public PromiseBase +{ + using value_type = T; + T val; + void return_value(T t) { val = std::move(t); } +}; + +template +struct ValuePromise : public PromiseBase +{ + using value_type = void; + void return_void() {} +}; + +template +class AwaitablePromise : public ValuePromise, T> +{ +public: + auto initial_suspend() { return std::suspend_never(); } + + auto final_suspend() noexcept + { + if (waiting_) + { + waiting_.resume(); + if (waiting_.done()) + { + waiting_.destroy(); + } + waiting_ = nullptr; + } + + return std::suspend_always(); + } + + void storeWaiting(std::coroutine_handle<> handle) + { + if (waiting_) + { + throw std::runtime_error("AwaitablePromise::storeWaiting(): already waiting"); + } + + waiting_ = handle; + } + + ~AwaitablePromise() + { + if (waiting_) + { + waiting_.destroy(); + } + } + +private: + std::coroutine_handle<> waiting_ = nullptr; +}; + +template +class TaskBase +{ +public: + using promise_type = Promise; + + TaskBase(); + TaskBase(std::coroutine_handle handle); + TaskBase(const TaskBase&) = delete; + TaskBase(TaskBase&&); + TaskBase& operator=(const TaskBase&) = delete; + TaskBase& operator=(TaskBase&&); + ~TaskBase(); + + bool done() const { return handle_.done(); } + +protected: + std::coroutine_handle handle_; + bool handleShouldBeDestroyed_; +}; + +template +TaskBase::TaskBase() + : handle_(nullptr), handleShouldBeDestroyed_(false) +{} + +template +TaskBase::TaskBase(std::coroutine_handle handle) + : handle_(handle) +{ + // TODO: this whole system needs revamping with something like UniqueCoroutineHandle + // and custom static interface to awaiter types - so await_suspend method would take in UniqueCoroutineHandle + + if (handle.done()) + { + // it is resonable to expect that if the coroutine is done before + // the task creation, then the original stack is continued without suspending, + // and coroutine needs to be destroyed with TaskBase object + handleShouldBeDestroyed_ = true; + } + else + { + // otherwise the coroutine should be managed by object that it is awaiting + handleShouldBeDestroyed_ = false; + } +} + +template +TaskBase::TaskBase(TaskBase&& other) + : handle_(other.handle_), handleShouldBeDestroyed_(std::exchange(other.handleShouldBeDestroyed_, false)) +{ +} + +template +TaskBase& TaskBase::operator=(TaskBase&& other) +{ + handle_ = other.handle_; + handleShouldBeDestroyed_ = std::exchange(other.handleShouldBeDestroyed_, false); + return *this; +} + +template +TaskBase::~TaskBase() +{ + if (handleShouldBeDestroyed_) + { + handle_.destroy(); + } +} + +} // namespace detail + +template +class AwaitableTask : public detail::TaskBase> +{ + using Base = detail::TaskBase>; + +public: + using Base::Base; + + class awaiter; + friend class awaiter; + awaiter operator co_await() const; +}; + +template +struct AwaitableTask::awaiter +{ + bool await_ready() + { + return task_.done(); + } + + template + void await_suspend(std::coroutine_handle handle) + { + task_.handle_.promise().storeWaiting(handle); + } + + T await_resume() + { + if constexpr (!std::is_same_v) + { + return std::move(task_.handle_.promise().val); + } + } + + const AwaitableTask& task_; +}; + +template +typename AwaitableTask::awaiter AwaitableTask::operator co_await() const +{ + return awaiter{*this}; +} + +} // namespace basiccoro diff --git a/basic-coro/LICENSE b/basic-coro/LICENSE new file mode 100644 index 0000000..c06a664 --- /dev/null +++ b/basic-coro/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Maksymilian Kadukowski + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/basic-coro/README.md b/basic-coro/README.md new file mode 100644 index 0000000..e085941 --- /dev/null +++ b/basic-coro/README.md @@ -0,0 +1,68 @@ +# basic-coro - c++ coroutine library +Library that implements helper types for using c++ coroutines. Please be aware that this is a training project for me - I wanted to learn more about CMake, gtest and git submodules. + +## Usage +### Prerequisites +* g++-10 + +### Installing +``` +mkdir build && cd build +cmake -D CMAKE_CXX_COMPILER=g++-10 .. +make install +``` +This will install appropriate headers into `./include/` and static linked library into `./lib/`. + +### Classes +Library includes following classes: +* `SingleEvent` which models `co_await` enabled event that can be set, +* `AwaitableTask` which models `co_await` enabled task. + +Please note that these classes are not multithreading enabled. There is no synchronization or any kind of protection form race conditions. If you need to use coroutines with multithreading, just use [CppCoro](https://github.com/lewissbaker/cppcoro). This library is mostly thought for use with simple GUI programming. + +### Example +```c++ +#include + +#include +#include + +basiccoro::AwaitableTask consumer(basiccoro::SingleEvent& event) +{ + std::cout << "consumer: start waiting" << std::endl; + + while (true) + { + const auto i = co_await event; + std::cout << "consumer: received: " << i << std::endl; + } +} + +int main() +{ + basiccoro::SingleEvent event; + consumer(event); + + while (true) + { + int i = 0; + + std::cout << "Enter no.(1-9): "; + std::cin >> i; + + if (i == 0) + { + break; + } + else if (1 <= i && i <= 9) + { + event.set(i); + } + } +} +``` +Simple example highlighting use of coroutines in producer-consumer problem. + +## Acknowledgments +* [CMake C++ Project Template](https://github.com/kigster/cmake-project-template) as this project is based on this template +* Lewis Baker has excellent [articles](https://lewissbaker.github.io/) on topic of coroutines and assymetric transfer. This project is mostly based on information (and code snippets) contained in those articles. diff --git a/basic-coro/SingleEvent.cpp b/basic-coro/SingleEvent.cpp new file mode 100644 index 0000000..8a63ba8 --- /dev/null +++ b/basic-coro/SingleEvent.cpp @@ -0,0 +1,58 @@ +#include "SingleEvent.hpp" + +namespace basiccoro +{ + +detail::SingleEventBase::SingleEventBase(detail::SingleEventBase&& other) + : waiting_(std::move(other.waiting_)) + , isSet_(std::exchange(other.isSet_, false)) +{ +} + +detail::SingleEventBase& detail::SingleEventBase::operator=(detail::SingleEventBase&& other) +{ + waiting_ = std::move(other.waiting_); + isSet_ = std::exchange(other.isSet_, false); + return *this; +} + +detail::SingleEventBase::~SingleEventBase() +{ + for (auto handle : waiting_) + { + handle.destroy(); + } +} + +void detail::SingleEventBase::set_common() +{ + if (!isSet_) + { + if (waiting_.empty()) + { + isSet_ = true; + } + else + { + // resuming coroutines can result in + // consecutive co_awaits on this object + auto temp = std::move(waiting_); + + for (auto handle : temp) + { + handle.resume(); + if (handle.done()) + { + handle.destroy(); + } + } + } + } +} + +SingleEvent::awaiter SingleEvent::operator co_await() +{ + return awaiter{*this}; +} + +} // namespace basiccoro diff --git a/basic-coro/SingleEvent.hpp b/basic-coro/SingleEvent.hpp new file mode 100644 index 0000000..dd4b4e3 --- /dev/null +++ b/basic-coro/SingleEvent.hpp @@ -0,0 +1,109 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace basiccoro +{ +namespace detail +{ + +template +class AwaiterBase +{ +public: + AwaiterBase(Event& event) + : event_(event) + {} + + bool await_ready() + { + if (event_.isSet()) + { + // unset already set event, then continue coroutine + event_.isSet_ = false; + return true; + } + + return false; + } + + void await_suspend(std::coroutine_handle<> handle) + { + event_.waiting_.push_back(handle); + } + + typename Event::value_type await_resume() + { + if constexpr (!std::is_same_v) + { + if (!event_.result) + { + throw std::runtime_error("AwaiterBase: no value in event_.result"); + } + return *event_.result; + } + } + +private: + Event& event_; +}; + +class SingleEventBase +{ +public: + SingleEventBase() = default; + SingleEventBase(const SingleEventBase&) = delete; + SingleEventBase(SingleEventBase&&); + SingleEventBase& operator=(const SingleEventBase&) = delete; + SingleEventBase& operator=(SingleEventBase&&); + ~SingleEventBase(); + + bool isSet() const { + return isSet_; + } + +protected: + void set_common(); + +private: + template + friend class AwaiterBase; + std::vector> waiting_; + bool isSet_ = false; +}; + +} // namespace detail + +template +class SingleEvent : public detail::SingleEventBase +{ +public: + using value_type = T; + using awaiter = detail::AwaiterBase>; + + void set(T t) { result = std::move(t); set_common(); } + awaiter operator co_await() { return awaiter{*this}; } + +private: + friend awaiter; + std::optional result; +}; + +template<> +class SingleEvent : public detail::SingleEventBase +{ +public: + using value_type = void; + using awaiter = detail::AwaiterBase>; + + void set() { + set_common(); + } + awaiter operator co_await(); +}; + +} // namespace basiccoro diff --git a/main.cpp b/main.cpp index 93a8023..7769191 100644 --- a/main.cpp +++ b/main.cpp @@ -1,5 +1,7 @@ #include "Runtime.hpp" +#include "AsyncManager.hpp" #include "Client.hpp" +#include "basic-coro/AwaitableTask.hpp" #include #include @@ -11,19 +13,12 @@ void on_progress(float progress) { std::cout << unsigned(progress) << '\r' << std::flush; } -int main() -{ - Runtime rt; - - // Print header - std::cout << "llama.any running on " PLATFORM ".\n" - "\n"; - +basiccoro::AwaitableTask async_main(Runtime& rt, AsyncManager &aMan) { // Ask for server address const std::string addr = rt.readInput("Server address"); // Create client - Client client(addr, 99181); + Client client(addr, 99181, aMan); // Connection loop for (;; rt.cooperate()) { @@ -37,15 +32,31 @@ int main() std::cout << "Prompt: " << prompt << std::endl; // Run inference - client.ask(prompt, [&rt] (float progress) { + co_await client.ask(prompt, [&rt] (float progress) -> basiccoro::AwaitableTask { std::cout << unsigned(progress) << "%\r" << std::flush; - rt.cooperate(); - }, [&rt] (std::string_view token) { + co_return; + }, [&rt] (std::string_view token) -> basiccoro::AwaitableTask { std::cout << token << std::flush; - rt.cooperate(); + co_return; }); std::cout << "\n"; } +} + +int main() +{ + Runtime rt; + AsyncManager aMan(rt); + + // Start async main() + async_main(rt, aMan); + + // Print header + std::cout << "llama.any running on " PLATFORM ".\n" + "\n"; + + // Start async manager + aMan.run(); return 0; }