1
0
Fork 0
mirror of https://gitlab.com/niansa/llama_nds.git synced 2025-03-06 20:53:28 +01:00

Added 'Result' type

This commit is contained in:
Nils Sauer 2023-04-06 10:00:27 +02:00
parent db3b5e343d
commit 5fca6fa023
9 changed files with 49 additions and 28 deletions

View file

@ -102,7 +102,7 @@ void AsyncManager::run() {
// Set futures
for (auto& [future, value] : execQueue) {
future->set(value);
future->set(value?Result::Error:Result::Success);
future = nullptr;
}

View file

@ -11,10 +11,14 @@ class Runtime;
enum class Result {
Error,
Success
};
class AsyncManager {
public:
using SockError = bool;
using SockFuture = basiccoro::SingleEvent<SockError>;
using SockFuture = basiccoro::SingleEvent<Result>;
using SockFutureUnique = std::unique_ptr<SockFuture>;
using SockFutureMap = std::unordered_multimap<int, SockFutureUnique>;
@ -39,13 +43,13 @@ public:
stopping = true;
}
basiccoro::AwaitableTask<SockError> waitRead(int fd) {
basiccoro::AwaitableTask<Result> waitRead(int fd) {
auto event = std::make_unique<SockFuture>();
auto eventPtr = event.get();
sockReads.emplace(fd, std::move(event));
co_return co_await *eventPtr;
}
basiccoro::AwaitableTask<SockError> waitWrite(int fd) {
basiccoro::AwaitableTask<Result> waitWrite(int fd) {
auto event = std::make_unique<SockFuture>();
auto eventPtr = event.get();
sockWrites.emplace(fd, std::move(event));

View file

@ -72,22 +72,28 @@ Client::Client(const std::string& addr, unsigned port, AsyncManager& asyncManage
# endif
}
basiccoro::AwaitableTask<void> Client::ask(std::string_view prompt, const std::function<basiccoro::AwaitableTask<void> (unsigned progress)>& on_progress, const std::function<basiccoro::AwaitableTask<void> (std::string_view token)>& on_token) {
basiccoro::AwaitableTask<Result> Client::ask(std::string_view prompt, const std::function<basiccoro::AwaitableTask<void> (unsigned progress)>& on_progress, const std::function<basiccoro::AwaitableTask<void> (std::string_view token)>& on_token) {
std::string fres;
// Send prompt length
uint8_t len = prompt.length();
co_await connection->writeObject(len, true);
if (co_await connection->writeObject(len, true) == Result::Error) {
co_return Result::Error;
}
// Send prompt
co_await connection->write(prompt);
if (co_await connection->write(prompt) == Result::Error) {
co_return Result::Error;
}
// Receive progress
for (;;) {
uint8_t progress;
// Receive percentage
co_await connection->readObject(progress);
if (co_await connection->readObject(progress) == Result::Error) {
co_return Result::Error;
}
// Run on_progress callback
co_await on_progress(progress);
@ -99,15 +105,26 @@ basiccoro::AwaitableTask<void> Client::ask(std::string_view prompt, const std::f
// Receive response
for (;;) {
// Receive response length
co_await connection->readObject(len);
if (co_await connection->readObject(len) == Result::Error) {
co_return Result::Error;
}
// End if zero
if (len == 0xFF) break;
// Skip empty token
if (len == 0) continue;
// Receive response
const auto token = co_await connection->read(len);
if (token.empty()) {
co_return Result::Error;
}
// Run on_token callback
co_await on_token(token);
}
// No error
co_return Result::Success;
}

View file

@ -43,7 +43,7 @@ class Client
public:
Client(const std::string &addr, unsigned port, AsyncManager& asyncManager);
basiccoro::AwaitableTask<void> ask(std::string_view prompt, const std::function<basiccoro::AwaitableTask<void> (unsigned progress)>& on_progress, const std::function<basiccoro::AwaitableTask<void> (std::string_view token)>& on_token);
basiccoro::AwaitableTask<Result> ask(std::string_view prompt, const std::function<basiccoro::AwaitableTask<void> (unsigned progress)>& on_progress, const std::function<basiccoro::AwaitableTask<void> (std::string_view token)>& on_token);
};
#endif // CLIENT_HPP

View file

@ -21,27 +21,27 @@ basiccoro::AwaitableTask<std::string> Receiver::Simple::read(size_t amount) {
// Return final buffer
co_return fres;
}
basiccoro::AwaitableTask<AsyncManager::SockError> Receiver::Simple::read(std::byte *buffer, size_t size) {
basiccoro::AwaitableTask<Result> Receiver::Simple::read(std::byte *buffer, size_t size) {
size_t allBytesRead = 0;
while (allBytesRead != size) {
// Wait for data
if (co_await aMan.waitRead(fd)) [[unlikely]] {
if (co_await aMan.waitRead(fd) == Result::Error) [[unlikely]] {
// Error
co_return true;
co_return Result::Error;
}
// Receive data
ssize_t bytesRead;
if ((bytesRead = recv(fd, reinterpret_cast<char*>(buffer+allBytesRead), size-allBytesRead, 0)) < 0) [[unlikely]] {
// Error
co_return true;
co_return Result::Error;
}
allBytesRead += bytesRead;
}
// No error
co_return false;
co_return Result::Success;
}
basiccoro::AwaitableTask<std::string> Receiver::Simple::readSome(size_t max) {
@ -50,7 +50,7 @@ basiccoro::AwaitableTask<std::string> Receiver::Simple::readSome(size_t max) {
fres.resize(max);
// Wait for data
if (co_await aMan.waitRead(fd)) [[unlikely]] {
if (co_await aMan.waitRead(fd) == Result::Error) [[unlikely]] {
co_return "";
}

View file

@ -21,7 +21,7 @@ public:
// Reads the exact amount of bytes given
basiccoro::AwaitableTask<std::string> read(size_t amount);
basiccoro::AwaitableTask<AsyncManager::SockError> read(std::byte *buffer, size_t size);
basiccoro::AwaitableTask<Result> read(std::byte *buffer, size_t size);
// Reads at max. the amount of bytes given
basiccoro::AwaitableTask<std::string> readSome(size_t max);

View file

@ -12,18 +12,18 @@
basiccoro::AwaitableTask<AsyncManager::SockError> Sender::Simple::write(std::string_view str, bool moreData) {
basiccoro::AwaitableTask<Result> Sender::Simple::write(std::string_view str, bool moreData) {
co_return co_await this->write(reinterpret_cast<const std::byte*>(str.data()), str.size(), moreData);
}
basiccoro::AwaitableTask<AsyncManager::SockError> Sender::Simple::write(const std::byte *data, size_t size, bool moreData) {
basiccoro::AwaitableTask<Result> Sender::Simple::write(const std::byte *data, size_t size, bool moreData) {
std::string fres;
// Wait for socket to get ready for writing
if (co_await aMan.waitWrite(fd)) [[unlikely]] {
co_return true;
if (co_await aMan.waitWrite(fd) == Result::Error) [[unlikely]] {
co_return Result::Error;
}
// Write
co_return send(fd, reinterpret_cast<const char*>(data), size, MSG_FLAGS_OR_ZERO(MSG_NOSIGNAL | (int(moreData)*MSG_MORE))) < 0;
co_return (send(fd, reinterpret_cast<const char*>(data), size, MSG_FLAGS_OR_ZERO(MSG_NOSIGNAL | (int(moreData)*MSG_MORE))) < 0)?Result::Error:Result::Success;
}

View file

@ -18,8 +18,8 @@ protected:
public:
Simple(AsyncManager& asyncManager, int fd) : aMan(asyncManager), fd(fd) {}
basiccoro::AwaitableTask<AsyncManager::SockError> write(std::string_view, bool moreData = false);
basiccoro::AwaitableTask<AsyncManager::SockError> write(const std::byte *data, size_t, bool moreData = false);
basiccoro::AwaitableTask<Result> write(std::string_view, bool moreData = false);
basiccoro::AwaitableTask<Result> write(const std::byte *data, size_t, bool moreData = false);
template<typename T>
auto writeObject(const T& o, bool moreData = false) {

View file

@ -47,13 +47,13 @@ int main() {
Runtime rt;
AsyncManager aMan(rt);
// Start async main()
async_main(rt, aMan);
// Print header
std::cout << "llama.any running on " PLATFORM ".\n"
"\n";
// Start async main()
async_main(rt, aMan);
// Start async manager
aMan.run();