1
0
Fork 0
mirror of https://gitlab.com/niansa/llama_any.git synced 2025-03-06 20:48:27 +01:00

Added 'Result' type

This commit is contained in:
Nils Sauer 2023-04-06 10:00:27 +02:00
parent db3b5e343d
commit 5fca6fa023
9 changed files with 49 additions and 28 deletions

View file

@ -102,7 +102,7 @@ void AsyncManager::run() {
// Set futures // Set futures
for (auto& [future, value] : execQueue) { for (auto& [future, value] : execQueue) {
future->set(value); future->set(value?Result::Error:Result::Success);
future = nullptr; future = nullptr;
} }

View file

@ -11,10 +11,14 @@ class Runtime;
enum class Result {
Error,
Success
};
class AsyncManager { class AsyncManager {
public: public:
using SockError = bool; using SockFuture = basiccoro::SingleEvent<Result>;
using SockFuture = basiccoro::SingleEvent<SockError>;
using SockFutureUnique = std::unique_ptr<SockFuture>; using SockFutureUnique = std::unique_ptr<SockFuture>;
using SockFutureMap = std::unordered_multimap<int, SockFutureUnique>; using SockFutureMap = std::unordered_multimap<int, SockFutureUnique>;
@ -39,13 +43,13 @@ public:
stopping = true; stopping = true;
} }
basiccoro::AwaitableTask<SockError> waitRead(int fd) { basiccoro::AwaitableTask<Result> waitRead(int fd) {
auto event = std::make_unique<SockFuture>(); auto event = std::make_unique<SockFuture>();
auto eventPtr = event.get(); auto eventPtr = event.get();
sockReads.emplace(fd, std::move(event)); sockReads.emplace(fd, std::move(event));
co_return co_await *eventPtr; co_return co_await *eventPtr;
} }
basiccoro::AwaitableTask<SockError> waitWrite(int fd) { basiccoro::AwaitableTask<Result> waitWrite(int fd) {
auto event = std::make_unique<SockFuture>(); auto event = std::make_unique<SockFuture>();
auto eventPtr = event.get(); auto eventPtr = event.get();
sockWrites.emplace(fd, std::move(event)); sockWrites.emplace(fd, std::move(event));

View file

@ -72,22 +72,28 @@ Client::Client(const std::string& addr, unsigned port, AsyncManager& asyncManage
# endif # endif
} }
basiccoro::AwaitableTask<void> Client::ask(std::string_view prompt, const std::function<basiccoro::AwaitableTask<void> (unsigned progress)>& on_progress, const std::function<basiccoro::AwaitableTask<void> (std::string_view token)>& on_token) { basiccoro::AwaitableTask<Result> Client::ask(std::string_view prompt, const std::function<basiccoro::AwaitableTask<void> (unsigned progress)>& on_progress, const std::function<basiccoro::AwaitableTask<void> (std::string_view token)>& on_token) {
std::string fres; std::string fres;
// Send prompt length // Send prompt length
uint8_t len = prompt.length(); uint8_t len = prompt.length();
co_await connection->writeObject(len, true); if (co_await connection->writeObject(len, true) == Result::Error) {
co_return Result::Error;
}
// Send prompt // Send prompt
co_await connection->write(prompt); if (co_await connection->write(prompt) == Result::Error) {
co_return Result::Error;
}
// Receive progress // Receive progress
for (;;) { for (;;) {
uint8_t progress; uint8_t progress;
// Receive percentage // Receive percentage
co_await connection->readObject(progress); if (co_await connection->readObject(progress) == Result::Error) {
co_return Result::Error;
}
// Run on_progress callback // Run on_progress callback
co_await on_progress(progress); co_await on_progress(progress);
@ -99,15 +105,26 @@ basiccoro::AwaitableTask<void> Client::ask(std::string_view prompt, const std::f
// Receive response // Receive response
for (;;) { for (;;) {
// Receive response length // Receive response length
co_await connection->readObject(len); if (co_await connection->readObject(len) == Result::Error) {
co_return Result::Error;
}
// End if zero // End if zero
if (len == 0xFF) break; if (len == 0xFF) break;
// Skip empty token
if (len == 0) continue;
// Receive response // Receive response
const auto token = co_await connection->read(len); const auto token = co_await connection->read(len);
if (token.empty()) {
co_return Result::Error;
}
// Run on_token callback // Run on_token callback
co_await on_token(token); co_await on_token(token);
} }
// No error
co_return Result::Success;
} }

View file

@ -43,7 +43,7 @@ class Client
public: public:
Client(const std::string &addr, unsigned port, AsyncManager& asyncManager); Client(const std::string &addr, unsigned port, AsyncManager& asyncManager);
basiccoro::AwaitableTask<void> ask(std::string_view prompt, const std::function<basiccoro::AwaitableTask<void> (unsigned progress)>& on_progress, const std::function<basiccoro::AwaitableTask<void> (std::string_view token)>& on_token); basiccoro::AwaitableTask<Result> ask(std::string_view prompt, const std::function<basiccoro::AwaitableTask<void> (unsigned progress)>& on_progress, const std::function<basiccoro::AwaitableTask<void> (std::string_view token)>& on_token);
}; };
#endif // CLIENT_HPP #endif // CLIENT_HPP

View file

@ -21,27 +21,27 @@ basiccoro::AwaitableTask<std::string> Receiver::Simple::read(size_t amount) {
// Return final buffer // Return final buffer
co_return fres; co_return fres;
} }
basiccoro::AwaitableTask<AsyncManager::SockError> Receiver::Simple::read(std::byte *buffer, size_t size) { basiccoro::AwaitableTask<Result> Receiver::Simple::read(std::byte *buffer, size_t size) {
size_t allBytesRead = 0; size_t allBytesRead = 0;
while (allBytesRead != size) { while (allBytesRead != size) {
// Wait for data // Wait for data
if (co_await aMan.waitRead(fd)) [[unlikely]] { if (co_await aMan.waitRead(fd) == Result::Error) [[unlikely]] {
// Error // Error
co_return true; co_return Result::Error;
} }
// Receive data // Receive data
ssize_t bytesRead; ssize_t bytesRead;
if ((bytesRead = recv(fd, reinterpret_cast<char*>(buffer+allBytesRead), size-allBytesRead, 0)) < 0) [[unlikely]] { if ((bytesRead = recv(fd, reinterpret_cast<char*>(buffer+allBytesRead), size-allBytesRead, 0)) < 0) [[unlikely]] {
// Error // Error
co_return true; co_return Result::Error;
} }
allBytesRead += bytesRead; allBytesRead += bytesRead;
} }
// No error // No error
co_return false; co_return Result::Success;
} }
basiccoro::AwaitableTask<std::string> Receiver::Simple::readSome(size_t max) { basiccoro::AwaitableTask<std::string> Receiver::Simple::readSome(size_t max) {
@ -50,7 +50,7 @@ basiccoro::AwaitableTask<std::string> Receiver::Simple::readSome(size_t max) {
fres.resize(max); fres.resize(max);
// Wait for data // Wait for data
if (co_await aMan.waitRead(fd)) [[unlikely]] { if (co_await aMan.waitRead(fd) == Result::Error) [[unlikely]] {
co_return ""; co_return "";
} }

View file

@ -21,7 +21,7 @@ public:
// Reads the exact amount of bytes given // Reads the exact amount of bytes given
basiccoro::AwaitableTask<std::string> read(size_t amount); basiccoro::AwaitableTask<std::string> read(size_t amount);
basiccoro::AwaitableTask<AsyncManager::SockError> read(std::byte *buffer, size_t size); basiccoro::AwaitableTask<Result> read(std::byte *buffer, size_t size);
// Reads at max. the amount of bytes given // Reads at max. the amount of bytes given
basiccoro::AwaitableTask<std::string> readSome(size_t max); basiccoro::AwaitableTask<std::string> readSome(size_t max);

View file

@ -12,18 +12,18 @@
basiccoro::AwaitableTask<AsyncManager::SockError> Sender::Simple::write(std::string_view str, bool moreData) { basiccoro::AwaitableTask<Result> Sender::Simple::write(std::string_view str, bool moreData) {
co_return co_await this->write(reinterpret_cast<const std::byte*>(str.data()), str.size(), moreData); co_return co_await this->write(reinterpret_cast<const std::byte*>(str.data()), str.size(), moreData);
} }
basiccoro::AwaitableTask<AsyncManager::SockError> Sender::Simple::write(const std::byte *data, size_t size, bool moreData) { basiccoro::AwaitableTask<Result> Sender::Simple::write(const std::byte *data, size_t size, bool moreData) {
std::string fres; std::string fres;
// Wait for socket to get ready for writing // Wait for socket to get ready for writing
if (co_await aMan.waitWrite(fd)) [[unlikely]] { if (co_await aMan.waitWrite(fd) == Result::Error) [[unlikely]] {
co_return true; co_return Result::Error;
} }
// Write // Write
co_return send(fd, reinterpret_cast<const char*>(data), size, MSG_FLAGS_OR_ZERO(MSG_NOSIGNAL | (int(moreData)*MSG_MORE))) < 0; co_return (send(fd, reinterpret_cast<const char*>(data), size, MSG_FLAGS_OR_ZERO(MSG_NOSIGNAL | (int(moreData)*MSG_MORE))) < 0)?Result::Error:Result::Success;
} }

View file

@ -18,8 +18,8 @@ protected:
public: public:
Simple(AsyncManager& asyncManager, int fd) : aMan(asyncManager), fd(fd) {} Simple(AsyncManager& asyncManager, int fd) : aMan(asyncManager), fd(fd) {}
basiccoro::AwaitableTask<AsyncManager::SockError> write(std::string_view, bool moreData = false); basiccoro::AwaitableTask<Result> write(std::string_view, bool moreData = false);
basiccoro::AwaitableTask<AsyncManager::SockError> write(const std::byte *data, size_t, bool moreData = false); basiccoro::AwaitableTask<Result> write(const std::byte *data, size_t, bool moreData = false);
template<typename T> template<typename T>
auto writeObject(const T& o, bool moreData = false) { auto writeObject(const T& o, bool moreData = false) {

View file

@ -47,13 +47,13 @@ int main() {
Runtime rt; Runtime rt;
AsyncManager aMan(rt); AsyncManager aMan(rt);
// Start async main()
async_main(rt, aMan);
// Print header // Print header
std::cout << "llama.any running on " PLATFORM ".\n" std::cout << "llama.any running on " PLATFORM ".\n"
"\n"; "\n";
// Start async main()
async_main(rt, aMan);
// Start async manager // Start async manager
aMan.run(); aMan.run();