diff --git a/AsyncManager.cpp b/AsyncManager.cpp index 0b7e269..76ad2fc 100644 --- a/AsyncManager.cpp +++ b/AsyncManager.cpp @@ -102,7 +102,7 @@ void AsyncManager::run() { // Set futures for (auto& [future, value] : execQueue) { - future->set(value); + future->set(value?Result::Error:Result::Success); future = nullptr; } diff --git a/AsyncManager.hpp b/AsyncManager.hpp index 94e7cec..c875217 100644 --- a/AsyncManager.hpp +++ b/AsyncManager.hpp @@ -11,10 +11,14 @@ class Runtime; +enum class Result { + Error, + Success +}; + class AsyncManager { public: - using SockError = bool; - using SockFuture = basiccoro::SingleEvent; + using SockFuture = basiccoro::SingleEvent; using SockFutureUnique = std::unique_ptr; using SockFutureMap = std::unordered_multimap; @@ -39,13 +43,13 @@ public: stopping = true; } - basiccoro::AwaitableTask waitRead(int fd) { + basiccoro::AwaitableTask waitRead(int fd) { auto event = std::make_unique(); auto eventPtr = event.get(); sockReads.emplace(fd, std::move(event)); co_return co_await *eventPtr; } - basiccoro::AwaitableTask waitWrite(int fd) { + basiccoro::AwaitableTask waitWrite(int fd) { auto event = std::make_unique(); auto eventPtr = event.get(); sockWrites.emplace(fd, std::move(event)); diff --git a/Client.cpp b/Client.cpp index 324a3a8..aecc907 100644 --- a/Client.cpp +++ b/Client.cpp @@ -72,22 +72,28 @@ Client::Client(const std::string& addr, unsigned port, AsyncManager& asyncManage # endif } -basiccoro::AwaitableTask Client::ask(std::string_view prompt, const std::function (unsigned progress)>& on_progress, const std::function (std::string_view token)>& on_token) { +basiccoro::AwaitableTask Client::ask(std::string_view prompt, const std::function (unsigned progress)>& on_progress, const std::function (std::string_view token)>& on_token) { std::string fres; // Send prompt length uint8_t len = prompt.length(); - co_await connection->writeObject(len, true); + if (co_await connection->writeObject(len, true) == Result::Error) { + co_return Result::Error; + } // Send prompt - co_await connection->write(prompt); + if (co_await connection->write(prompt) == Result::Error) { + co_return Result::Error; + } // Receive progress for (;;) { uint8_t progress; // Receive percentage - co_await connection->readObject(progress); + if (co_await connection->readObject(progress) == Result::Error) { + co_return Result::Error; + } // Run on_progress callback co_await on_progress(progress); @@ -99,15 +105,26 @@ basiccoro::AwaitableTask Client::ask(std::string_view prompt, const std::f // Receive response for (;;) { // Receive response length - co_await connection->readObject(len); + if (co_await connection->readObject(len) == Result::Error) { + co_return Result::Error; + } // End if zero if (len == 0xFF) break; + // Skip empty token + if (len == 0) continue; + // Receive response const auto token = co_await connection->read(len); + if (token.empty()) { + co_return Result::Error; + } // Run on_token callback co_await on_token(token); } + + // No error + co_return Result::Success; } diff --git a/Client.hpp b/Client.hpp index 1feafe6..ab2a627 100644 --- a/Client.hpp +++ b/Client.hpp @@ -43,7 +43,7 @@ class Client public: Client(const std::string &addr, unsigned port, AsyncManager& asyncManager); - basiccoro::AwaitableTask ask(std::string_view prompt, const std::function (unsigned progress)>& on_progress, const std::function (std::string_view token)>& on_token); + basiccoro::AwaitableTask ask(std::string_view prompt, const std::function (unsigned progress)>& on_progress, const std::function (std::string_view token)>& on_token); }; #endif // CLIENT_HPP diff --git a/Receiver.cpp b/Receiver.cpp index 30a08c0..7839c01 100644 --- a/Receiver.cpp +++ b/Receiver.cpp @@ -21,27 +21,27 @@ basiccoro::AwaitableTask Receiver::Simple::read(size_t amount) { // Return final buffer co_return fres; } -basiccoro::AwaitableTask Receiver::Simple::read(std::byte *buffer, size_t size) { +basiccoro::AwaitableTask Receiver::Simple::read(std::byte *buffer, size_t size) { size_t allBytesRead = 0; while (allBytesRead != size) { // Wait for data - if (co_await aMan.waitRead(fd)) [[unlikely]] { + if (co_await aMan.waitRead(fd) == Result::Error) [[unlikely]] { // Error - co_return true; + co_return Result::Error; } // Receive data ssize_t bytesRead; if ((bytesRead = recv(fd, reinterpret_cast(buffer+allBytesRead), size-allBytesRead, 0)) < 0) [[unlikely]] { // Error - co_return true; + co_return Result::Error; } allBytesRead += bytesRead; } // No error - co_return false; + co_return Result::Success; } basiccoro::AwaitableTask Receiver::Simple::readSome(size_t max) { @@ -50,7 +50,7 @@ basiccoro::AwaitableTask Receiver::Simple::readSome(size_t max) { fres.resize(max); // Wait for data - if (co_await aMan.waitRead(fd)) [[unlikely]] { + if (co_await aMan.waitRead(fd) == Result::Error) [[unlikely]] { co_return ""; } diff --git a/Receiver.hpp b/Receiver.hpp index 6d8060f..d5fd970 100644 --- a/Receiver.hpp +++ b/Receiver.hpp @@ -21,7 +21,7 @@ public: // Reads the exact amount of bytes given basiccoro::AwaitableTask read(size_t amount); - basiccoro::AwaitableTask read(std::byte *buffer, size_t size); + basiccoro::AwaitableTask read(std::byte *buffer, size_t size); // Reads at max. the amount of bytes given basiccoro::AwaitableTask readSome(size_t max); diff --git a/Sender.cpp b/Sender.cpp index 5327120..7af72ea 100644 --- a/Sender.cpp +++ b/Sender.cpp @@ -12,18 +12,18 @@ -basiccoro::AwaitableTask Sender::Simple::write(std::string_view str, bool moreData) { +basiccoro::AwaitableTask Sender::Simple::write(std::string_view str, bool moreData) { co_return co_await this->write(reinterpret_cast(str.data()), str.size(), moreData); } -basiccoro::AwaitableTask Sender::Simple::write(const std::byte *data, size_t size, bool moreData) { +basiccoro::AwaitableTask Sender::Simple::write(const std::byte *data, size_t size, bool moreData) { std::string fres; // Wait for socket to get ready for writing - if (co_await aMan.waitWrite(fd)) [[unlikely]] { - co_return true; + if (co_await aMan.waitWrite(fd) == Result::Error) [[unlikely]] { + co_return Result::Error; } // Write - co_return send(fd, reinterpret_cast(data), size, MSG_FLAGS_OR_ZERO(MSG_NOSIGNAL | (int(moreData)*MSG_MORE))) < 0; + co_return (send(fd, reinterpret_cast(data), size, MSG_FLAGS_OR_ZERO(MSG_NOSIGNAL | (int(moreData)*MSG_MORE))) < 0)?Result::Error:Result::Success; } diff --git a/Sender.hpp b/Sender.hpp index 404606a..c842d01 100644 --- a/Sender.hpp +++ b/Sender.hpp @@ -18,8 +18,8 @@ protected: public: Simple(AsyncManager& asyncManager, int fd) : aMan(asyncManager), fd(fd) {} - basiccoro::AwaitableTask write(std::string_view, bool moreData = false); - basiccoro::AwaitableTask write(const std::byte *data, size_t, bool moreData = false); + basiccoro::AwaitableTask write(std::string_view, bool moreData = false); + basiccoro::AwaitableTask write(const std::byte *data, size_t, bool moreData = false); template auto writeObject(const T& o, bool moreData = false) { diff --git a/main.cpp b/main.cpp index 02be45c..ebaf237 100644 --- a/main.cpp +++ b/main.cpp @@ -47,13 +47,13 @@ int main() { Runtime rt; AsyncManager aMan(rt); - // Start async main() - async_main(rt, aMan); - // Print header std::cout << "llama.any running on " PLATFORM ".\n" "\n"; + // Start async main() + async_main(rt, aMan); + // Start async manager aMan.run();