1
0
Fork 0
mirror of https://gitlab.com/niansa/llama_nds.git synced 2025-03-06 20:53:28 +01:00

Use basic-coro for coroutines

This commit is contained in:
Nils Sauer 2023-04-06 09:07:53 +02:00
parent 76d5f397ff
commit f303aca0c9
16 changed files with 734 additions and 55 deletions

114
AsyncManager.cpp Normal file
View file

@ -0,0 +1,114 @@
#include "AsyncManager.hpp"
#include "Runtime.hpp"
#include <vector>
#include <utility>
#ifndef PLATFORM_WINDOWS
# include <sys/select.h>
#else
# include <ws2tcpip.h>
#endif
void AsyncManager::cleanFutureMap(SockFutureMap& map) {
std::vector<SockFutureMap::iterator> erasureQueue;
for (auto it = map.begin(); it != map.end(); it++) {
if (!it->second) [[unlikely]] {
erasureQueue.push_back(it);
}
}
for (auto& it : erasureQueue) {
map.erase(it);
}
}
void AsyncManager::run() {
while (!stopping && runtime.cooperate()) {
// We should stop once there is nothihng left to do
if (sockReads.empty() && sockWrites.empty()) [[unlikely]] {
break;
}
// We need to keep track of the highest fd for socket()
int maxFd = 0;
// Create except FD set
fd_set exceptFds;
FD_ZERO(&exceptFds);
// Create write FD set
fd_set writeFds;
FD_ZERO(&writeFds);
for (const auto& [fd, cb] : sockWrites) {
FD_SET(fd, &writeFds);
FD_SET(fd, &exceptFds);
if (fd > maxFd) {
maxFd = fd;
}
}
// Create read FD set
fd_set readFds;
FD_ZERO(&readFds);
for (const auto& [fd, cb] : sockReads) {
FD_SET(fd, &readFds);
FD_SET(fd, &exceptFds);
if (fd > maxFd) {
maxFd = fd;
}
}
// Specify timeout
timeval tv{
.tv_sec = 0,
.tv_usec = 250000
};
// select() until there is data
bool error = false;
if (select(maxFd+1, &readFds, &writeFds, &exceptFds, &tv) < 0) {
FD_ZERO(&readFds);
FD_ZERO(&writeFds);
error = true;
}
// Execution queue
std::vector<std::pair<SockFutureUnique&, bool>> execQueue;
// Collect all write futures
for (auto& [fd, future] : sockWrites) {
if (FD_ISSET(fd, &writeFds)) {
// Socket is ready for writing
execQueue.push_back({future, false});
}
if (FD_ISSET(fd, &exceptFds) || error) [[unlikely]] {
// An exception happened in the socket
execQueue.push_back({future, true});
}
}
// Collect all read futures
for (auto& [fd, future] : sockReads) {
if (FD_ISSET(fd, &readFds)) {
// Socket is ready for reading
execQueue.push_back({future, false});
}
if (FD_ISSET(fd, &exceptFds) || error) [[unlikely]] {
// An exception happened in the socket
execQueue.push_back({future, true});
}
}
// Set futures
for (auto& [future, value] : execQueue) {
future->set(value);
future = nullptr;
}
// Clean future maps
cleanFutureMap(sockWrites);
cleanFutureMap(sockReads);
}
stopping = false;
}

59
AsyncManager.hpp Normal file
View file

@ -0,0 +1,59 @@
#ifndef _ASYNCMANAGER_HPP
#define _ASYNCMANAGER_HPP
#include "Runtime.hpp"
#include <unordered_map>
#include <memory>
#include "basic-coro/AwaitableTask.hpp"
#include "basic-coro/SingleEvent.hpp"
class Runtime;
class AsyncManager {
public:
using SockError = bool;
using SockFuture = basiccoro::SingleEvent<SockError>;
using SockFutureUnique = std::unique_ptr<SockFuture>;
using SockFutureMap = std::unordered_multimap<int, SockFutureUnique>;
private:
Runtime& runtime;
SockFutureMap sockReads;
SockFutureMap sockWrites;
bool stopping = false;
static
void cleanFutureMap(SockFutureMap&);
public:
AsyncManager(Runtime& runtime) : runtime(runtime) {}
AsyncManager(AsyncManager&) = delete;
AsyncManager(const AsyncManager&) = delete;
AsyncManager(AsyncManager&&) = delete;
void run();
void stop() {
stopping = true;
}
basiccoro::AwaitableTask<SockError> waitRead(int fd) {
auto event = std::make_unique<SockFuture>();
auto eventPtr = event.get();
sockReads.emplace(fd, std::move(event));
co_return co_await *eventPtr;
}
basiccoro::AwaitableTask<SockError> waitWrite(int fd) {
auto event = std::make_unique<SockFuture>();
auto eventPtr = event.get();
sockWrites.emplace(fd, std::move(event));
co_return co_await *eventPtr;
}
auto& getRuntime() const {
return runtime;
}
};
#endif

View file

@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.5)
project(llama.any LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
add_executable(llama
@ -11,7 +11,9 @@ add_executable(llama
Socket.hpp
Receiver.hpp Receiver.cpp
Sender.hpp Sender.cpp
AsyncManager.hpp AsyncManager.cpp
Runtime.cpp Runtime.hpp
basic-coro/AwaitableTask.hpp basic-coro/SingleEvent.hpp basic-coro/SingleEvent.cpp
)
target_compile_definitions(llama PUBLIC PLATFORM="${CMAKE_SYSTEM_NAME}")

View file

@ -45,9 +45,9 @@ void Client::fetchAddr(const std::string& addr, unsigned port) {
}
}
Client::Client(const std::string& addr, unsigned port) {
Client::Client(const std::string& addr, unsigned port, AsyncManager& asyncManager) : aMan(asyncManager) {
// Create socket
connection = std::make_unique<SocketConnection<Sender::Simple, Receiver::Simple>>(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); //TODO: Care about IPv6
connection = std::make_unique<SocketConnection<Sender::Simple, Receiver::Simple>>(aMan, Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); //TODO: Care about IPv6
if (*connection < 0) [[unlikely]] {
throw Exception("Failed to create TCP socket");
}
@ -72,25 +72,25 @@ Client::Client(const std::string& addr, unsigned port) {
# endif
}
void Client::ask(std::string_view prompt, const std::function<void (unsigned progress)>& on_progress, const std::function<void (std::string_view token)>& on_token) {
basiccoro::AwaitableTask<void> Client::ask(std::string_view prompt, const std::function<basiccoro::AwaitableTask<void> (unsigned progress)>& on_progress, const std::function<basiccoro::AwaitableTask<void> (std::string_view token)>& on_token) {
std::string fres;
// Send prompt length
uint8_t len = prompt.length();
connection->writeObject(len, true);
co_await connection->writeObject(len, true);
// Send prompt
connection->write(prompt);
co_await connection->write(prompt);
// Receive progress
for (;;) {
uint8_t progress;
// Receive percentage
connection->readObject(progress);
co_await connection->readObject(progress);
// Run on_progress callback
on_progress(progress);
co_await on_progress(progress);
// Stop at 100%
if (progress == 100) break;
@ -99,15 +99,15 @@ void Client::ask(std::string_view prompt, const std::function<void (unsigned pro
// Receive response
for (;;) {
// Receive response length
connection->readObject(len);
co_await connection->readObject(len);
// End if zero
if (len == 0xFF) break;
// Receive response
const auto token = connection->read(len);
const auto token = co_await connection->read(len);
// Run on_token callback
on_token(token);
co_await on_token(token);
}
}

View file

@ -1,9 +1,11 @@
#ifndef CLIENT_HPP
#define CLIENT_HPP
#include "Runtime.hpp"
#include "AsyncManager.hpp"
#include "Socket.hpp"
#include "Sender.hpp"
#include "Receiver.hpp"
#include "basic-coro/AwaitableTask.hpp"
#include <string>
#include <string_view>
@ -23,6 +25,8 @@ class Client
using std::runtime_error::runtime_error;
};
AsyncManager& aMan;
int fd = -1;
# ifdef HAS_ADDRINFO
addrinfo
@ -37,9 +41,9 @@ class Client
void fetchAddr(const std::string& addr, unsigned port);
public:
Client(const std::string &addr, unsigned port);
Client(const std::string &addr, unsigned port, AsyncManager& asyncManager);
void ask(std::string_view prompt, const std::function<void (unsigned progress)>& on_progress, const std::function<void (std::string_view token)>& on_token);
basiccoro::AwaitableTask<void> ask(std::string_view prompt, const std::function<basiccoro::AwaitableTask<void> (unsigned progress)>& on_progress, const std::function<basiccoro::AwaitableTask<void> (std::string_view token)>& on_token);
};
#endif // CLIENT_HPP

View file

@ -4,40 +4,63 @@
#include <array>
#ifndef PLATFORM_WINDOWS
# include <sys/select.h>
# include <sys/socket.h>
#else
# include <ws2tcpip.h>
#endif
std::string Receiver::Simple::read(size_t amount) {
basiccoro::AwaitableTask<std::string> Receiver::Simple::read(size_t amount) {
// Create buffer
std::string fres;
fres.resize(amount);
// Read into buffer
read(reinterpret_cast<std::byte*>(fres.data()), fres.size());
co_await read(reinterpret_cast<std::byte*>(fres.data()), fres.size());
// Return final buffer
return fres;
co_return fres;
}
void Receiver::Simple::read(std::byte *buffer, size_t size) {
recv(fd, reinterpret_cast<char*>(buffer), size, MSG_WAITALL);
basiccoro::AwaitableTask<AsyncManager::SockError> Receiver::Simple::read(std::byte *buffer, size_t size) {
size_t allBytesRead = 0;
while (allBytesRead != size) {
// Wait for data
if (co_await aMan.waitRead(fd)) [[unlikely]] {
// Error
co_return true;
}
std::string Receiver::Simple::readSome(size_t max) {
// Receive data
ssize_t bytesRead;
if ((bytesRead = recv(fd, reinterpret_cast<char*>(buffer+allBytesRead), size-allBytesRead, 0)) < 0) [[unlikely]] {
// Error
co_return true;
}
allBytesRead += bytesRead;
}
// No error
co_return false;
}
basiccoro::AwaitableTask<std::string> Receiver::Simple::readSome(size_t max) {
// Create buffer
std::string fres;
fres.resize(max);
// Wait for data
if (co_await aMan.waitRead(fd)) [[unlikely]] {
co_return "";
}
// Receive data
ssize_t bytesRead;
if ((bytesRead = recv(fd, fres.data(), max, MSG_WAITALL)) < 0) [[unlikely]] {
return "";
if ((bytesRead = recv(fd, fres.data(), max, 0)) < 0) [[unlikely]] {
co_return "";
}
// Resize and return final buffer
fres.resize(bytesRead);
return fres;
co_return fres;
}

View file

@ -1,24 +1,29 @@
#ifndef _RECEIVER_HPP
#define _RECEIVER_HPP
#include "Runtime.hpp"
#include "AsyncManager.hpp"
#include <string>
#include <cstddef>
#include "basic-coro/AwaitableTask.hpp"
namespace Receiver {
class Simple {
Runtime& runtime;
AsyncManager &aMan;
protected:
int fd;
public:
Simple(int fd) : fd(fd) {}
Simple(AsyncManager& asyncManager, int fd) : runtime(asyncManager.getRuntime()), aMan(asyncManager), fd(fd) {}
// Reads the exact amount of bytes given
std::string read(size_t amount);
void read(std::byte *buffer, size_t size);
basiccoro::AwaitableTask<std::string> read(size_t amount);
basiccoro::AwaitableTask<AsyncManager::SockError> read(std::byte *buffer, size_t size);
// Reads at max. the amount of bytes given
std::string readSome(size_t max);
basiccoro::AwaitableTask<std::string> readSome(size_t max);
// Reads an object of type T
template<typename T>

View file

@ -1,9 +1,9 @@
#include "Runtime.hpp"
#include "Sender.hpp"
#include "basic-coro/AwaitableTask.hpp"
#include <string_view>
#ifndef PLATFORM_WINDOWS
# include <sys/socket.h>
# include <sys/select.h>
#else
# include <ws2tcpip.h>
@ -12,13 +12,18 @@
void Sender::Simple::write(std::string_view str, bool moreData) {
this->write(reinterpret_cast<const std::byte*>(str.data()), str.size(), moreData);
basiccoro::AwaitableTask<AsyncManager::SockError> Sender::Simple::write(std::string_view str, bool moreData) {
co_return co_await this->write(reinterpret_cast<const std::byte*>(str.data()), str.size(), moreData);
}
void Sender::Simple::write(const std::byte *data, size_t size, bool moreData) {
basiccoro::AwaitableTask<AsyncManager::SockError> Sender::Simple::write(const std::byte *data, size_t size, bool moreData) {
std::string fres;
// Write
send(fd, reinterpret_cast<const char*>(data), size, MSG_FLAGS_OR_ZERO(MSG_WAITALL | MSG_NOSIGNAL | (int(moreData)*MSG_MORE)));
// Wait for socket to get ready for writing
if (co_await aMan.waitWrite(fd)) [[unlikely]] {
co_return true;
}
// Write
co_return send(fd, reinterpret_cast<const char*>(data), size, MSG_FLAGS_OR_ZERO(MSG_NOSIGNAL | (int(moreData)*MSG_MORE))) < 0;
}

View file

@ -1,20 +1,25 @@
#ifndef _SENDER_HPP
#define _SENDER_HPP
#include "AsyncManager.hpp"
#include <string_view>
#include <vector>
#include <cstddef>
#include "basic-coro/AwaitableTask.hpp"
namespace Sender {
class Simple {
AsyncManager &aMan;
protected:
int fd;
public:
Simple(int fd) : fd(fd) {}
Simple(AsyncManager& asyncManager, int fd) : aMan(asyncManager), fd(fd) {}
void write(std::string_view, bool moreData = false);
void write(const std::byte *data, size_t, bool moreData = false);
basiccoro::AwaitableTask<AsyncManager::SockError> write(std::string_view, bool moreData = false);
basiccoro::AwaitableTask<AsyncManager::SockError> write(const std::byte *data, size_t, bool moreData = false);
template<typename T>
auto writeObject(const T& o, bool moreData = false) {

View file

@ -1,14 +1,14 @@
#ifndef _SOCKET_HPP
#define _SOCKET_HPP
#include "AsyncManager.hpp"
#include <memory>
#include <unistd.h>
#if defined(PLATFORM_WINDOWS)
# include <ws2tcpip.h>
#elif defined(PLATFORM_WII)
#include <network.h>
#else
#ifndef PLATFORM_WINDOWS
# include <sys/socket.h>
# include <sys/select.h>
#else
# include <ws2tcpip.h>
#endif
@ -58,9 +58,9 @@ public:
template<class SenderT, class ReceiverT>
class SocketConnection : public SenderT, public ReceiverT, public Socket {
public:
SocketConnection(Socket&& socket)
SocketConnection(AsyncManager& asyncManager, Socket&& socket)
// Double-initialization seems to yield better assembly
: SenderT(socket), ReceiverT(socket), Socket(std::move(socket)) {
: SenderT(asyncManager, socket), ReceiverT(asyncManager, socket), Socket(std::move(socket)) {
SenderT::fd = get();
ReceiverT::fd = get();
}

View file

@ -0,0 +1,195 @@
#pragma once
#include <concepts>
#include <coroutine>
#include <exception>
#include <stdexcept>
#include <utility>
namespace basiccoro
{
namespace detail
{
template<class Derived>
struct PromiseBase
{
auto get_return_object() { return std::coroutine_handle<Derived>::from_promise(static_cast<Derived&>(*this)); }
void unhandled_exception() { std::terminate(); }
};
template<class Derived, class T> requires std::movable<T> || std::same_as<T, void>
struct ValuePromise : public PromiseBase<Derived>
{
using value_type = T;
T val;
void return_value(T t) { val = std::move(t); }
};
template<class Derived>
struct ValuePromise<Derived, void> : public PromiseBase<Derived>
{
using value_type = void;
void return_void() {}
};
template<class T>
class AwaitablePromise : public ValuePromise<AwaitablePromise<T>, T>
{
public:
auto initial_suspend() { return std::suspend_never(); }
auto final_suspend() noexcept
{
if (waiting_)
{
waiting_.resume();
if (waiting_.done())
{
waiting_.destroy();
}
waiting_ = nullptr;
}
return std::suspend_always();
}
void storeWaiting(std::coroutine_handle<> handle)
{
if (waiting_)
{
throw std::runtime_error("AwaitablePromise::storeWaiting(): already waiting");
}
waiting_ = handle;
}
~AwaitablePromise()
{
if (waiting_)
{
waiting_.destroy();
}
}
private:
std::coroutine_handle<> waiting_ = nullptr;
};
template<class Promise>
class TaskBase
{
public:
using promise_type = Promise;
TaskBase();
TaskBase(std::coroutine_handle<promise_type> handle);
TaskBase(const TaskBase&) = delete;
TaskBase(TaskBase&&);
TaskBase& operator=(const TaskBase&) = delete;
TaskBase& operator=(TaskBase&&);
~TaskBase();
bool done() const { return handle_.done(); }
protected:
std::coroutine_handle<promise_type> handle_;
bool handleShouldBeDestroyed_;
};
template<class Promise>
TaskBase<Promise>::TaskBase()
: handle_(nullptr), handleShouldBeDestroyed_(false)
{}
template<class Promise>
TaskBase<Promise>::TaskBase(std::coroutine_handle<promise_type> handle)
: handle_(handle)
{
// TODO: this whole system needs revamping with something like UniqueCoroutineHandle
// and custom static interface to awaiter types - so await_suspend method would take in UniqueCoroutineHandle
if (handle.done())
{
// it is resonable to expect that if the coroutine is done before
// the task creation, then the original stack is continued without suspending,
// and coroutine needs to be destroyed with TaskBase object
handleShouldBeDestroyed_ = true;
}
else
{
// otherwise the coroutine should be managed by object that it is awaiting
handleShouldBeDestroyed_ = false;
}
}
template<class Promise>
TaskBase<Promise>::TaskBase(TaskBase&& other)
: handle_(other.handle_), handleShouldBeDestroyed_(std::exchange(other.handleShouldBeDestroyed_, false))
{
}
template<class Promise>
TaskBase<Promise>& TaskBase<Promise>::operator=(TaskBase&& other)
{
handle_ = other.handle_;
handleShouldBeDestroyed_ = std::exchange(other.handleShouldBeDestroyed_, false);
return *this;
}
template<class Promise>
TaskBase<Promise>::~TaskBase()
{
if (handleShouldBeDestroyed_)
{
handle_.destroy();
}
}
} // namespace detail
template<class T>
class AwaitableTask : public detail::TaskBase<detail::AwaitablePromise<T>>
{
using Base = detail::TaskBase<detail::AwaitablePromise<T>>;
public:
using Base::Base;
class awaiter;
friend class awaiter;
awaiter operator co_await() const;
};
template<class T>
struct AwaitableTask<T>::awaiter
{
bool await_ready()
{
return task_.done();
}
template<class Promise>
void await_suspend(std::coroutine_handle<Promise> handle)
{
task_.handle_.promise().storeWaiting(handle);
}
T await_resume()
{
if constexpr (!std::is_same_v<void, T>)
{
return std::move(task_.handle_.promise().val);
}
}
const AwaitableTask& task_;
};
template<class T>
typename AwaitableTask<T>::awaiter AwaitableTask<T>::operator co_await() const
{
return awaiter{*this};
}
} // namespace basiccoro

21
basic-coro/LICENSE Normal file
View file

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2021 Maksymilian Kadukowski
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

68
basic-coro/README.md Normal file
View file

@ -0,0 +1,68 @@
# basic-coro - c++ coroutine library
Library that implements helper types for using c++ coroutines. Please be aware that this is a training project for me - I wanted to learn more about CMake, gtest and git submodules.
## Usage
### Prerequisites
* g++-10
### Installing
```
mkdir build && cd build
cmake -D CMAKE_CXX_COMPILER=g++-10 ..
make install
```
This will install appropriate headers into `./include/` and static linked library into `./lib/`.
### Classes
Library includes following classes:
* `SingleEvent<T>` which models `co_await` enabled event that can be set,
* `AwaitableTask<T>` which models `co_await` enabled task.
Please note that these classes are not multithreading enabled. There is no synchronization or any kind of protection form race conditions. If you need to use coroutines with multithreading, just use [CppCoro](https://github.com/lewissbaker/cppcoro). This library is mostly thought for use with simple GUI programming.
### Example
```c++
#include <iostream>
#include <basiccoro/AwaitableTask.hpp>
#include <basiccoro/SingleEvent.hpp>
basiccoro::AwaitableTask<void> consumer(basiccoro::SingleEvent<int>& event)
{
std::cout << "consumer: start waiting" << std::endl;
while (true)
{
const auto i = co_await event;
std::cout << "consumer: received: " << i << std::endl;
}
}
int main()
{
basiccoro::SingleEvent<int> event;
consumer(event);
while (true)
{
int i = 0;
std::cout << "Enter no.(1-9): ";
std::cin >> i;
if (i == 0)
{
break;
}
else if (1 <= i && i <= 9)
{
event.set(i);
}
}
}
```
Simple example highlighting use of coroutines in producer-consumer problem.
## Acknowledgments
* [CMake C++ Project Template](https://github.com/kigster/cmake-project-template) as this project is based on this template
* Lewis Baker has excellent [articles](https://lewissbaker.github.io/) on topic of coroutines and assymetric transfer. This project is mostly based on information (and code snippets) contained in those articles.

View file

@ -0,0 +1,58 @@
#include "SingleEvent.hpp"
namespace basiccoro
{
detail::SingleEventBase::SingleEventBase(detail::SingleEventBase&& other)
: waiting_(std::move(other.waiting_))
, isSet_(std::exchange(other.isSet_, false))
{
}
detail::SingleEventBase& detail::SingleEventBase::operator=(detail::SingleEventBase&& other)
{
waiting_ = std::move(other.waiting_);
isSet_ = std::exchange(other.isSet_, false);
return *this;
}
detail::SingleEventBase::~SingleEventBase()
{
for (auto handle : waiting_)
{
handle.destroy();
}
}
void detail::SingleEventBase::set_common()
{
if (!isSet_)
{
if (waiting_.empty())
{
isSet_ = true;
}
else
{
// resuming coroutines can result in
// consecutive co_awaits on this object
auto temp = std::move(waiting_);
for (auto handle : temp)
{
handle.resume();
if (handle.done())
{
handle.destroy();
}
}
}
}
}
SingleEvent<void>::awaiter SingleEvent<void>::operator co_await()
{
return awaiter{*this};
}
} // namespace basiccoro

109
basic-coro/SingleEvent.hpp Normal file
View file

@ -0,0 +1,109 @@
#pragma once
#include <coroutine>
#include <optional>
#include <stdexcept>
#include <type_traits>
#include <vector>
namespace basiccoro
{
namespace detail
{
template<class Event>
class AwaiterBase
{
public:
AwaiterBase(Event& event)
: event_(event)
{}
bool await_ready()
{
if (event_.isSet())
{
// unset already set event, then continue coroutine
event_.isSet_ = false;
return true;
}
return false;
}
void await_suspend(std::coroutine_handle<> handle)
{
event_.waiting_.push_back(handle);
}
typename Event::value_type await_resume()
{
if constexpr (!std::is_same_v<typename Event::value_type, void>)
{
if (!event_.result)
{
throw std::runtime_error("AwaiterBase: no value in event_.result");
}
return *event_.result;
}
}
private:
Event& event_;
};
class SingleEventBase
{
public:
SingleEventBase() = default;
SingleEventBase(const SingleEventBase&) = delete;
SingleEventBase(SingleEventBase&&);
SingleEventBase& operator=(const SingleEventBase&) = delete;
SingleEventBase& operator=(SingleEventBase&&);
~SingleEventBase();
bool isSet() const {
return isSet_;
}
protected:
void set_common();
private:
template<class T>
friend class AwaiterBase;
std::vector<std::coroutine_handle<>> waiting_;
bool isSet_ = false;
};
} // namespace detail
template<class T>
class SingleEvent : public detail::SingleEventBase
{
public:
using value_type = T;
using awaiter = detail::AwaiterBase<SingleEvent<T>>;
void set(T t) { result = std::move(t); set_common(); }
awaiter operator co_await() { return awaiter{*this}; }
private:
friend awaiter;
std::optional<T> result;
};
template<>
class SingleEvent<void> : public detail::SingleEventBase
{
public:
using value_type = void;
using awaiter = detail::AwaiterBase<SingleEvent<void>>;
void set() {
set_common();
}
awaiter operator co_await();
};
} // namespace basiccoro

View file

@ -1,5 +1,7 @@
#include "Runtime.hpp"
#include "AsyncManager.hpp"
#include "Client.hpp"
#include "basic-coro/AwaitableTask.hpp"
#include <iostream>
#include <string>
@ -11,19 +13,12 @@ void on_progress(float progress) {
std::cout << unsigned(progress) << '\r' << std::flush;
}
int main()
{
Runtime rt;
// Print header
std::cout << "llama.any running on " PLATFORM ".\n"
"\n";
basiccoro::AwaitableTask<void> async_main(Runtime& rt, AsyncManager &aMan) {
// Ask for server address
const std::string addr = rt.readInput("Server address");
// Create client
Client client(addr, 99181);
Client client(addr, 99181, aMan);
// Connection loop
for (;; rt.cooperate()) {
@ -37,15 +32,31 @@ int main()
std::cout << "Prompt: " << prompt << std::endl;
// Run inference
client.ask(prompt, [&rt] (float progress) {
co_await client.ask(prompt, [&rt] (float progress) -> basiccoro::AwaitableTask<void> {
std::cout << unsigned(progress) << "%\r" << std::flush;
rt.cooperate();
}, [&rt] (std::string_view token) {
co_return;
}, [&rt] (std::string_view token) -> basiccoro::AwaitableTask<void> {
std::cout << token << std::flush;
rt.cooperate();
co_return;
});
std::cout << "\n";
}
}
int main()
{
Runtime rt;
AsyncManager aMan(rt);
// Start async main()
async_main(rt, aMan);
// Print header
std::cout << "llama.any running on " PLATFORM ".\n"
"\n";
// Start async manager
aMan.run();
return 0;
}