mirror of
https://gitlab.com/niansa/llama_nds.git
synced 2025-03-06 20:53:28 +01:00
Use basic-coro for coroutines
This commit is contained in:
parent
76d5f397ff
commit
f303aca0c9
16 changed files with 734 additions and 55 deletions
114
AsyncManager.cpp
Normal file
114
AsyncManager.cpp
Normal file
|
@ -0,0 +1,114 @@
|
|||
#include "AsyncManager.hpp"
|
||||
#include "Runtime.hpp"
|
||||
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#ifndef PLATFORM_WINDOWS
|
||||
# include <sys/select.h>
|
||||
#else
|
||||
# include <ws2tcpip.h>
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
void AsyncManager::cleanFutureMap(SockFutureMap& map) {
|
||||
std::vector<SockFutureMap::iterator> erasureQueue;
|
||||
for (auto it = map.begin(); it != map.end(); it++) {
|
||||
if (!it->second) [[unlikely]] {
|
||||
erasureQueue.push_back(it);
|
||||
}
|
||||
}
|
||||
for (auto& it : erasureQueue) {
|
||||
map.erase(it);
|
||||
}
|
||||
}
|
||||
|
||||
void AsyncManager::run() {
|
||||
while (!stopping && runtime.cooperate()) {
|
||||
// We should stop once there is nothihng left to do
|
||||
if (sockReads.empty() && sockWrites.empty()) [[unlikely]] {
|
||||
break;
|
||||
}
|
||||
|
||||
// We need to keep track of the highest fd for socket()
|
||||
int maxFd = 0;
|
||||
|
||||
// Create except FD set
|
||||
fd_set exceptFds;
|
||||
FD_ZERO(&exceptFds);
|
||||
|
||||
// Create write FD set
|
||||
fd_set writeFds;
|
||||
FD_ZERO(&writeFds);
|
||||
for (const auto& [fd, cb] : sockWrites) {
|
||||
FD_SET(fd, &writeFds);
|
||||
FD_SET(fd, &exceptFds);
|
||||
if (fd > maxFd) {
|
||||
maxFd = fd;
|
||||
}
|
||||
}
|
||||
|
||||
// Create read FD set
|
||||
fd_set readFds;
|
||||
FD_ZERO(&readFds);
|
||||
for (const auto& [fd, cb] : sockReads) {
|
||||
FD_SET(fd, &readFds);
|
||||
FD_SET(fd, &exceptFds);
|
||||
if (fd > maxFd) {
|
||||
maxFd = fd;
|
||||
}
|
||||
}
|
||||
|
||||
// Specify timeout
|
||||
timeval tv{
|
||||
.tv_sec = 0,
|
||||
.tv_usec = 250000
|
||||
};
|
||||
|
||||
// select() until there is data
|
||||
bool error = false;
|
||||
if (select(maxFd+1, &readFds, &writeFds, &exceptFds, &tv) < 0) {
|
||||
FD_ZERO(&readFds);
|
||||
FD_ZERO(&writeFds);
|
||||
error = true;
|
||||
}
|
||||
|
||||
// Execution queue
|
||||
std::vector<std::pair<SockFutureUnique&, bool>> execQueue;
|
||||
|
||||
// Collect all write futures
|
||||
for (auto& [fd, future] : sockWrites) {
|
||||
if (FD_ISSET(fd, &writeFds)) {
|
||||
// Socket is ready for writing
|
||||
execQueue.push_back({future, false});
|
||||
}
|
||||
if (FD_ISSET(fd, &exceptFds) || error) [[unlikely]] {
|
||||
// An exception happened in the socket
|
||||
execQueue.push_back({future, true});
|
||||
}
|
||||
}
|
||||
|
||||
// Collect all read futures
|
||||
for (auto& [fd, future] : sockReads) {
|
||||
if (FD_ISSET(fd, &readFds)) {
|
||||
// Socket is ready for reading
|
||||
execQueue.push_back({future, false});
|
||||
}
|
||||
if (FD_ISSET(fd, &exceptFds) || error) [[unlikely]] {
|
||||
// An exception happened in the socket
|
||||
execQueue.push_back({future, true});
|
||||
}
|
||||
}
|
||||
|
||||
// Set futures
|
||||
for (auto& [future, value] : execQueue) {
|
||||
future->set(value);
|
||||
future = nullptr;
|
||||
}
|
||||
|
||||
// Clean future maps
|
||||
cleanFutureMap(sockWrites);
|
||||
cleanFutureMap(sockReads);
|
||||
}
|
||||
stopping = false;
|
||||
}
|
59
AsyncManager.hpp
Normal file
59
AsyncManager.hpp
Normal file
|
@ -0,0 +1,59 @@
|
|||
#ifndef _ASYNCMANAGER_HPP
|
||||
#define _ASYNCMANAGER_HPP
|
||||
#include "Runtime.hpp"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
#include "basic-coro/AwaitableTask.hpp"
|
||||
#include "basic-coro/SingleEvent.hpp"
|
||||
|
||||
class Runtime;
|
||||
|
||||
|
||||
|
||||
class AsyncManager {
|
||||
public:
|
||||
using SockError = bool;
|
||||
using SockFuture = basiccoro::SingleEvent<SockError>;
|
||||
using SockFutureUnique = std::unique_ptr<SockFuture>;
|
||||
using SockFutureMap = std::unordered_multimap<int, SockFutureUnique>;
|
||||
|
||||
private:
|
||||
Runtime& runtime;
|
||||
|
||||
SockFutureMap sockReads;
|
||||
SockFutureMap sockWrites;
|
||||
bool stopping = false;
|
||||
|
||||
static
|
||||
void cleanFutureMap(SockFutureMap&);
|
||||
|
||||
public:
|
||||
AsyncManager(Runtime& runtime) : runtime(runtime) {}
|
||||
AsyncManager(AsyncManager&) = delete;
|
||||
AsyncManager(const AsyncManager&) = delete;
|
||||
AsyncManager(AsyncManager&&) = delete;
|
||||
|
||||
void run();
|
||||
void stop() {
|
||||
stopping = true;
|
||||
}
|
||||
|
||||
basiccoro::AwaitableTask<SockError> waitRead(int fd) {
|
||||
auto event = std::make_unique<SockFuture>();
|
||||
auto eventPtr = event.get();
|
||||
sockReads.emplace(fd, std::move(event));
|
||||
co_return co_await *eventPtr;
|
||||
}
|
||||
basiccoro::AwaitableTask<SockError> waitWrite(int fd) {
|
||||
auto event = std::make_unique<SockFuture>();
|
||||
auto eventPtr = event.get();
|
||||
sockWrites.emplace(fd, std::move(event));
|
||||
co_return co_await *eventPtr;
|
||||
}
|
||||
|
||||
auto& getRuntime() const {
|
||||
return runtime;
|
||||
}
|
||||
};
|
||||
#endif
|
|
@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.5)
|
|||
|
||||
project(llama.any LANGUAGES CXX)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD 20)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
add_executable(llama
|
||||
|
@ -11,7 +11,9 @@ add_executable(llama
|
|||
Socket.hpp
|
||||
Receiver.hpp Receiver.cpp
|
||||
Sender.hpp Sender.cpp
|
||||
AsyncManager.hpp AsyncManager.cpp
|
||||
Runtime.cpp Runtime.hpp
|
||||
basic-coro/AwaitableTask.hpp basic-coro/SingleEvent.hpp basic-coro/SingleEvent.cpp
|
||||
)
|
||||
|
||||
target_compile_definitions(llama PUBLIC PLATFORM="${CMAKE_SYSTEM_NAME}")
|
||||
|
|
20
Client.cpp
20
Client.cpp
|
@ -45,9 +45,9 @@ void Client::fetchAddr(const std::string& addr, unsigned port) {
|
|||
}
|
||||
}
|
||||
|
||||
Client::Client(const std::string& addr, unsigned port) {
|
||||
Client::Client(const std::string& addr, unsigned port, AsyncManager& asyncManager) : aMan(asyncManager) {
|
||||
// Create socket
|
||||
connection = std::make_unique<SocketConnection<Sender::Simple, Receiver::Simple>>(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); //TODO: Care about IPv6
|
||||
connection = std::make_unique<SocketConnection<Sender::Simple, Receiver::Simple>>(aMan, Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); //TODO: Care about IPv6
|
||||
if (*connection < 0) [[unlikely]] {
|
||||
throw Exception("Failed to create TCP socket");
|
||||
}
|
||||
|
@ -72,25 +72,25 @@ Client::Client(const std::string& addr, unsigned port) {
|
|||
# endif
|
||||
}
|
||||
|
||||
void Client::ask(std::string_view prompt, const std::function<void (unsigned progress)>& on_progress, const std::function<void (std::string_view token)>& on_token) {
|
||||
basiccoro::AwaitableTask<void> Client::ask(std::string_view prompt, const std::function<basiccoro::AwaitableTask<void> (unsigned progress)>& on_progress, const std::function<basiccoro::AwaitableTask<void> (std::string_view token)>& on_token) {
|
||||
std::string fres;
|
||||
|
||||
// Send prompt length
|
||||
uint8_t len = prompt.length();
|
||||
connection->writeObject(len, true);
|
||||
co_await connection->writeObject(len, true);
|
||||
|
||||
// Send prompt
|
||||
connection->write(prompt);
|
||||
co_await connection->write(prompt);
|
||||
|
||||
// Receive progress
|
||||
for (;;) {
|
||||
uint8_t progress;
|
||||
|
||||
// Receive percentage
|
||||
connection->readObject(progress);
|
||||
co_await connection->readObject(progress);
|
||||
|
||||
// Run on_progress callback
|
||||
on_progress(progress);
|
||||
co_await on_progress(progress);
|
||||
|
||||
// Stop at 100%
|
||||
if (progress == 100) break;
|
||||
|
@ -99,15 +99,15 @@ void Client::ask(std::string_view prompt, const std::function<void (unsigned pro
|
|||
// Receive response
|
||||
for (;;) {
|
||||
// Receive response length
|
||||
connection->readObject(len);
|
||||
co_await connection->readObject(len);
|
||||
|
||||
// End if zero
|
||||
if (len == 0xFF) break;
|
||||
|
||||
// Receive response
|
||||
const auto token = connection->read(len);
|
||||
const auto token = co_await connection->read(len);
|
||||
|
||||
// Run on_token callback
|
||||
on_token(token);
|
||||
co_await on_token(token);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
#ifndef CLIENT_HPP
|
||||
#define CLIENT_HPP
|
||||
#include "Runtime.hpp"
|
||||
#include "AsyncManager.hpp"
|
||||
#include "Socket.hpp"
|
||||
#include "Sender.hpp"
|
||||
#include "Receiver.hpp"
|
||||
#include "basic-coro/AwaitableTask.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
|
@ -23,6 +25,8 @@ class Client
|
|||
using std::runtime_error::runtime_error;
|
||||
};
|
||||
|
||||
AsyncManager& aMan;
|
||||
|
||||
int fd = -1;
|
||||
# ifdef HAS_ADDRINFO
|
||||
addrinfo
|
||||
|
@ -37,9 +41,9 @@ class Client
|
|||
void fetchAddr(const std::string& addr, unsigned port);
|
||||
|
||||
public:
|
||||
Client(const std::string &addr, unsigned port);
|
||||
Client(const std::string &addr, unsigned port, AsyncManager& asyncManager);
|
||||
|
||||
void ask(std::string_view prompt, const std::function<void (unsigned progress)>& on_progress, const std::function<void (std::string_view token)>& on_token);
|
||||
basiccoro::AwaitableTask<void> ask(std::string_view prompt, const std::function<basiccoro::AwaitableTask<void> (unsigned progress)>& on_progress, const std::function<basiccoro::AwaitableTask<void> (std::string_view token)>& on_token);
|
||||
};
|
||||
|
||||
#endif // CLIENT_HPP
|
||||
|
|
43
Receiver.cpp
43
Receiver.cpp
|
@ -4,40 +4,63 @@
|
|||
#include <array>
|
||||
#ifndef PLATFORM_WINDOWS
|
||||
# include <sys/select.h>
|
||||
# include <sys/socket.h>
|
||||
#else
|
||||
# include <ws2tcpip.h>
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
std::string Receiver::Simple::read(size_t amount) {
|
||||
basiccoro::AwaitableTask<std::string> Receiver::Simple::read(size_t amount) {
|
||||
// Create buffer
|
||||
std::string fres;
|
||||
fres.resize(amount);
|
||||
|
||||
// Read into buffer
|
||||
read(reinterpret_cast<std::byte*>(fres.data()), fres.size());
|
||||
co_await read(reinterpret_cast<std::byte*>(fres.data()), fres.size());
|
||||
|
||||
// Return final buffer
|
||||
return fres;
|
||||
co_return fres;
|
||||
}
|
||||
void Receiver::Simple::read(std::byte *buffer, size_t size) {
|
||||
recv(fd, reinterpret_cast<char*>(buffer), size, MSG_WAITALL);
|
||||
basiccoro::AwaitableTask<AsyncManager::SockError> Receiver::Simple::read(std::byte *buffer, size_t size) {
|
||||
size_t allBytesRead = 0;
|
||||
|
||||
while (allBytesRead != size) {
|
||||
// Wait for data
|
||||
if (co_await aMan.waitRead(fd)) [[unlikely]] {
|
||||
// Error
|
||||
co_return true;
|
||||
}
|
||||
|
||||
std::string Receiver::Simple::readSome(size_t max) {
|
||||
// Receive data
|
||||
ssize_t bytesRead;
|
||||
if ((bytesRead = recv(fd, reinterpret_cast<char*>(buffer+allBytesRead), size-allBytesRead, 0)) < 0) [[unlikely]] {
|
||||
// Error
|
||||
co_return true;
|
||||
}
|
||||
allBytesRead += bytesRead;
|
||||
}
|
||||
|
||||
// No error
|
||||
co_return false;
|
||||
}
|
||||
|
||||
basiccoro::AwaitableTask<std::string> Receiver::Simple::readSome(size_t max) {
|
||||
// Create buffer
|
||||
std::string fres;
|
||||
fres.resize(max);
|
||||
|
||||
// Wait for data
|
||||
if (co_await aMan.waitRead(fd)) [[unlikely]] {
|
||||
co_return "";
|
||||
}
|
||||
|
||||
// Receive data
|
||||
ssize_t bytesRead;
|
||||
if ((bytesRead = recv(fd, fres.data(), max, MSG_WAITALL)) < 0) [[unlikely]] {
|
||||
return "";
|
||||
if ((bytesRead = recv(fd, fres.data(), max, 0)) < 0) [[unlikely]] {
|
||||
co_return "";
|
||||
}
|
||||
|
||||
// Resize and return final buffer
|
||||
fres.resize(bytesRead);
|
||||
return fres;
|
||||
co_return fres;
|
||||
}
|
||||
|
|
13
Receiver.hpp
13
Receiver.hpp
|
@ -1,24 +1,29 @@
|
|||
#ifndef _RECEIVER_HPP
|
||||
#define _RECEIVER_HPP
|
||||
#include "Runtime.hpp"
|
||||
#include "AsyncManager.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <cstddef>
|
||||
#include "basic-coro/AwaitableTask.hpp"
|
||||
|
||||
|
||||
namespace Receiver {
|
||||
class Simple {
|
||||
Runtime& runtime;
|
||||
AsyncManager &aMan;
|
||||
|
||||
protected:
|
||||
int fd;
|
||||
|
||||
public:
|
||||
Simple(int fd) : fd(fd) {}
|
||||
Simple(AsyncManager& asyncManager, int fd) : runtime(asyncManager.getRuntime()), aMan(asyncManager), fd(fd) {}
|
||||
|
||||
// Reads the exact amount of bytes given
|
||||
std::string read(size_t amount);
|
||||
void read(std::byte *buffer, size_t size);
|
||||
basiccoro::AwaitableTask<std::string> read(size_t amount);
|
||||
basiccoro::AwaitableTask<AsyncManager::SockError> read(std::byte *buffer, size_t size);
|
||||
// Reads at max. the amount of bytes given
|
||||
std::string readSome(size_t max);
|
||||
basiccoro::AwaitableTask<std::string> readSome(size_t max);
|
||||
|
||||
// Reads an object of type T
|
||||
template<typename T>
|
||||
|
|
17
Sender.cpp
17
Sender.cpp
|
@ -1,9 +1,9 @@
|
|||
#include "Runtime.hpp"
|
||||
#include "Sender.hpp"
|
||||
#include "basic-coro/AwaitableTask.hpp"
|
||||
|
||||
#include <string_view>
|
||||
#ifndef PLATFORM_WINDOWS
|
||||
# include <sys/socket.h>
|
||||
# include <sys/select.h>
|
||||
#else
|
||||
# include <ws2tcpip.h>
|
||||
|
@ -12,13 +12,18 @@
|
|||
|
||||
|
||||
|
||||
void Sender::Simple::write(std::string_view str, bool moreData) {
|
||||
this->write(reinterpret_cast<const std::byte*>(str.data()), str.size(), moreData);
|
||||
basiccoro::AwaitableTask<AsyncManager::SockError> Sender::Simple::write(std::string_view str, bool moreData) {
|
||||
co_return co_await this->write(reinterpret_cast<const std::byte*>(str.data()), str.size(), moreData);
|
||||
}
|
||||
|
||||
void Sender::Simple::write(const std::byte *data, size_t size, bool moreData) {
|
||||
basiccoro::AwaitableTask<AsyncManager::SockError> Sender::Simple::write(const std::byte *data, size_t size, bool moreData) {
|
||||
std::string fres;
|
||||
|
||||
// Write
|
||||
send(fd, reinterpret_cast<const char*>(data), size, MSG_FLAGS_OR_ZERO(MSG_WAITALL | MSG_NOSIGNAL | (int(moreData)*MSG_MORE)));
|
||||
// Wait for socket to get ready for writing
|
||||
if (co_await aMan.waitWrite(fd)) [[unlikely]] {
|
||||
co_return true;
|
||||
}
|
||||
|
||||
// Write
|
||||
co_return send(fd, reinterpret_cast<const char*>(data), size, MSG_FLAGS_OR_ZERO(MSG_NOSIGNAL | (int(moreData)*MSG_MORE))) < 0;
|
||||
}
|
||||
|
|
11
Sender.hpp
11
Sender.hpp
|
@ -1,20 +1,25 @@
|
|||
#ifndef _SENDER_HPP
|
||||
#define _SENDER_HPP
|
||||
#include "AsyncManager.hpp"
|
||||
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
#include <cstddef>
|
||||
#include "basic-coro/AwaitableTask.hpp"
|
||||
|
||||
|
||||
namespace Sender {
|
||||
class Simple {
|
||||
AsyncManager &aMan;
|
||||
|
||||
protected:
|
||||
int fd;
|
||||
|
||||
public:
|
||||
Simple(int fd) : fd(fd) {}
|
||||
Simple(AsyncManager& asyncManager, int fd) : aMan(asyncManager), fd(fd) {}
|
||||
|
||||
void write(std::string_view, bool moreData = false);
|
||||
void write(const std::byte *data, size_t, bool moreData = false);
|
||||
basiccoro::AwaitableTask<AsyncManager::SockError> write(std::string_view, bool moreData = false);
|
||||
basiccoro::AwaitableTask<AsyncManager::SockError> write(const std::byte *data, size_t, bool moreData = false);
|
||||
|
||||
template<typename T>
|
||||
auto writeObject(const T& o, bool moreData = false) {
|
||||
|
|
14
Socket.hpp
14
Socket.hpp
|
@ -1,14 +1,14 @@
|
|||
#ifndef _SOCKET_HPP
|
||||
#define _SOCKET_HPP
|
||||
#include "AsyncManager.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <unistd.h>
|
||||
#if defined(PLATFORM_WINDOWS)
|
||||
# include <ws2tcpip.h>
|
||||
#elif defined(PLATFORM_WII)
|
||||
#include <network.h>
|
||||
#else
|
||||
#ifndef PLATFORM_WINDOWS
|
||||
# include <sys/socket.h>
|
||||
# include <sys/select.h>
|
||||
#else
|
||||
# include <ws2tcpip.h>
|
||||
#endif
|
||||
|
||||
|
||||
|
@ -58,9 +58,9 @@ public:
|
|||
template<class SenderT, class ReceiverT>
|
||||
class SocketConnection : public SenderT, public ReceiverT, public Socket {
|
||||
public:
|
||||
SocketConnection(Socket&& socket)
|
||||
SocketConnection(AsyncManager& asyncManager, Socket&& socket)
|
||||
// Double-initialization seems to yield better assembly
|
||||
: SenderT(socket), ReceiverT(socket), Socket(std::move(socket)) {
|
||||
: SenderT(asyncManager, socket), ReceiverT(asyncManager, socket), Socket(std::move(socket)) {
|
||||
SenderT::fd = get();
|
||||
ReceiverT::fd = get();
|
||||
}
|
||||
|
|
195
basic-coro/AwaitableTask.hpp
Normal file
195
basic-coro/AwaitableTask.hpp
Normal file
|
@ -0,0 +1,195 @@
|
|||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
#include <coroutine>
|
||||
#include <exception>
|
||||
#include <stdexcept>
|
||||
#include <utility>
|
||||
|
||||
namespace basiccoro
|
||||
{
|
||||
namespace detail
|
||||
{
|
||||
|
||||
template<class Derived>
|
||||
struct PromiseBase
|
||||
{
|
||||
auto get_return_object() { return std::coroutine_handle<Derived>::from_promise(static_cast<Derived&>(*this)); }
|
||||
void unhandled_exception() { std::terminate(); }
|
||||
};
|
||||
|
||||
template<class Derived, class T> requires std::movable<T> || std::same_as<T, void>
|
||||
struct ValuePromise : public PromiseBase<Derived>
|
||||
{
|
||||
using value_type = T;
|
||||
T val;
|
||||
void return_value(T t) { val = std::move(t); }
|
||||
};
|
||||
|
||||
template<class Derived>
|
||||
struct ValuePromise<Derived, void> : public PromiseBase<Derived>
|
||||
{
|
||||
using value_type = void;
|
||||
void return_void() {}
|
||||
};
|
||||
|
||||
template<class T>
|
||||
class AwaitablePromise : public ValuePromise<AwaitablePromise<T>, T>
|
||||
{
|
||||
public:
|
||||
auto initial_suspend() { return std::suspend_never(); }
|
||||
|
||||
auto final_suspend() noexcept
|
||||
{
|
||||
if (waiting_)
|
||||
{
|
||||
waiting_.resume();
|
||||
if (waiting_.done())
|
||||
{
|
||||
waiting_.destroy();
|
||||
}
|
||||
waiting_ = nullptr;
|
||||
}
|
||||
|
||||
return std::suspend_always();
|
||||
}
|
||||
|
||||
void storeWaiting(std::coroutine_handle<> handle)
|
||||
{
|
||||
if (waiting_)
|
||||
{
|
||||
throw std::runtime_error("AwaitablePromise::storeWaiting(): already waiting");
|
||||
}
|
||||
|
||||
waiting_ = handle;
|
||||
}
|
||||
|
||||
~AwaitablePromise()
|
||||
{
|
||||
if (waiting_)
|
||||
{
|
||||
waiting_.destroy();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::coroutine_handle<> waiting_ = nullptr;
|
||||
};
|
||||
|
||||
template<class Promise>
|
||||
class TaskBase
|
||||
{
|
||||
public:
|
||||
using promise_type = Promise;
|
||||
|
||||
TaskBase();
|
||||
TaskBase(std::coroutine_handle<promise_type> handle);
|
||||
TaskBase(const TaskBase&) = delete;
|
||||
TaskBase(TaskBase&&);
|
||||
TaskBase& operator=(const TaskBase&) = delete;
|
||||
TaskBase& operator=(TaskBase&&);
|
||||
~TaskBase();
|
||||
|
||||
bool done() const { return handle_.done(); }
|
||||
|
||||
protected:
|
||||
std::coroutine_handle<promise_type> handle_;
|
||||
bool handleShouldBeDestroyed_;
|
||||
};
|
||||
|
||||
template<class Promise>
|
||||
TaskBase<Promise>::TaskBase()
|
||||
: handle_(nullptr), handleShouldBeDestroyed_(false)
|
||||
{}
|
||||
|
||||
template<class Promise>
|
||||
TaskBase<Promise>::TaskBase(std::coroutine_handle<promise_type> handle)
|
||||
: handle_(handle)
|
||||
{
|
||||
// TODO: this whole system needs revamping with something like UniqueCoroutineHandle
|
||||
// and custom static interface to awaiter types - so await_suspend method would take in UniqueCoroutineHandle
|
||||
|
||||
if (handle.done())
|
||||
{
|
||||
// it is resonable to expect that if the coroutine is done before
|
||||
// the task creation, then the original stack is continued without suspending,
|
||||
// and coroutine needs to be destroyed with TaskBase object
|
||||
handleShouldBeDestroyed_ = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
// otherwise the coroutine should be managed by object that it is awaiting
|
||||
handleShouldBeDestroyed_ = false;
|
||||
}
|
||||
}
|
||||
|
||||
template<class Promise>
|
||||
TaskBase<Promise>::TaskBase(TaskBase&& other)
|
||||
: handle_(other.handle_), handleShouldBeDestroyed_(std::exchange(other.handleShouldBeDestroyed_, false))
|
||||
{
|
||||
}
|
||||
|
||||
template<class Promise>
|
||||
TaskBase<Promise>& TaskBase<Promise>::operator=(TaskBase&& other)
|
||||
{
|
||||
handle_ = other.handle_;
|
||||
handleShouldBeDestroyed_ = std::exchange(other.handleShouldBeDestroyed_, false);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<class Promise>
|
||||
TaskBase<Promise>::~TaskBase()
|
||||
{
|
||||
if (handleShouldBeDestroyed_)
|
||||
{
|
||||
handle_.destroy();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template<class T>
|
||||
class AwaitableTask : public detail::TaskBase<detail::AwaitablePromise<T>>
|
||||
{
|
||||
using Base = detail::TaskBase<detail::AwaitablePromise<T>>;
|
||||
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
class awaiter;
|
||||
friend class awaiter;
|
||||
awaiter operator co_await() const;
|
||||
};
|
||||
|
||||
template<class T>
|
||||
struct AwaitableTask<T>::awaiter
|
||||
{
|
||||
bool await_ready()
|
||||
{
|
||||
return task_.done();
|
||||
}
|
||||
|
||||
template<class Promise>
|
||||
void await_suspend(std::coroutine_handle<Promise> handle)
|
||||
{
|
||||
task_.handle_.promise().storeWaiting(handle);
|
||||
}
|
||||
|
||||
T await_resume()
|
||||
{
|
||||
if constexpr (!std::is_same_v<void, T>)
|
||||
{
|
||||
return std::move(task_.handle_.promise().val);
|
||||
}
|
||||
}
|
||||
|
||||
const AwaitableTask& task_;
|
||||
};
|
||||
|
||||
template<class T>
|
||||
typename AwaitableTask<T>::awaiter AwaitableTask<T>::operator co_await() const
|
||||
{
|
||||
return awaiter{*this};
|
||||
}
|
||||
|
||||
} // namespace basiccoro
|
21
basic-coro/LICENSE
Normal file
21
basic-coro/LICENSE
Normal file
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2021 Maksymilian Kadukowski
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
68
basic-coro/README.md
Normal file
68
basic-coro/README.md
Normal file
|
@ -0,0 +1,68 @@
|
|||
# basic-coro - c++ coroutine library
|
||||
Library that implements helper types for using c++ coroutines. Please be aware that this is a training project for me - I wanted to learn more about CMake, gtest and git submodules.
|
||||
|
||||
## Usage
|
||||
### Prerequisites
|
||||
* g++-10
|
||||
|
||||
### Installing
|
||||
```
|
||||
mkdir build && cd build
|
||||
cmake -D CMAKE_CXX_COMPILER=g++-10 ..
|
||||
make install
|
||||
```
|
||||
This will install appropriate headers into `./include/` and static linked library into `./lib/`.
|
||||
|
||||
### Classes
|
||||
Library includes following classes:
|
||||
* `SingleEvent<T>` which models `co_await` enabled event that can be set,
|
||||
* `AwaitableTask<T>` which models `co_await` enabled task.
|
||||
|
||||
Please note that these classes are not multithreading enabled. There is no synchronization or any kind of protection form race conditions. If you need to use coroutines with multithreading, just use [CppCoro](https://github.com/lewissbaker/cppcoro). This library is mostly thought for use with simple GUI programming.
|
||||
|
||||
### Example
|
||||
```c++
|
||||
#include <iostream>
|
||||
|
||||
#include <basiccoro/AwaitableTask.hpp>
|
||||
#include <basiccoro/SingleEvent.hpp>
|
||||
|
||||
basiccoro::AwaitableTask<void> consumer(basiccoro::SingleEvent<int>& event)
|
||||
{
|
||||
std::cout << "consumer: start waiting" << std::endl;
|
||||
|
||||
while (true)
|
||||
{
|
||||
const auto i = co_await event;
|
||||
std::cout << "consumer: received: " << i << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
basiccoro::SingleEvent<int> event;
|
||||
consumer(event);
|
||||
|
||||
while (true)
|
||||
{
|
||||
int i = 0;
|
||||
|
||||
std::cout << "Enter no.(1-9): ";
|
||||
std::cin >> i;
|
||||
|
||||
if (i == 0)
|
||||
{
|
||||
break;
|
||||
}
|
||||
else if (1 <= i && i <= 9)
|
||||
{
|
||||
event.set(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
Simple example highlighting use of coroutines in producer-consumer problem.
|
||||
|
||||
## Acknowledgments
|
||||
* [CMake C++ Project Template](https://github.com/kigster/cmake-project-template) as this project is based on this template
|
||||
* Lewis Baker has excellent [articles](https://lewissbaker.github.io/) on topic of coroutines and assymetric transfer. This project is mostly based on information (and code snippets) contained in those articles.
|
58
basic-coro/SingleEvent.cpp
Normal file
58
basic-coro/SingleEvent.cpp
Normal file
|
@ -0,0 +1,58 @@
|
|||
#include "SingleEvent.hpp"
|
||||
|
||||
namespace basiccoro
|
||||
{
|
||||
|
||||
detail::SingleEventBase::SingleEventBase(detail::SingleEventBase&& other)
|
||||
: waiting_(std::move(other.waiting_))
|
||||
, isSet_(std::exchange(other.isSet_, false))
|
||||
{
|
||||
}
|
||||
|
||||
detail::SingleEventBase& detail::SingleEventBase::operator=(detail::SingleEventBase&& other)
|
||||
{
|
||||
waiting_ = std::move(other.waiting_);
|
||||
isSet_ = std::exchange(other.isSet_, false);
|
||||
return *this;
|
||||
}
|
||||
|
||||
detail::SingleEventBase::~SingleEventBase()
|
||||
{
|
||||
for (auto handle : waiting_)
|
||||
{
|
||||
handle.destroy();
|
||||
}
|
||||
}
|
||||
|
||||
void detail::SingleEventBase::set_common()
|
||||
{
|
||||
if (!isSet_)
|
||||
{
|
||||
if (waiting_.empty())
|
||||
{
|
||||
isSet_ = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
// resuming coroutines can result in
|
||||
// consecutive co_awaits on this object
|
||||
auto temp = std::move(waiting_);
|
||||
|
||||
for (auto handle : temp)
|
||||
{
|
||||
handle.resume();
|
||||
if (handle.done())
|
||||
{
|
||||
handle.destroy();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SingleEvent<void>::awaiter SingleEvent<void>::operator co_await()
|
||||
{
|
||||
return awaiter{*this};
|
||||
}
|
||||
|
||||
} // namespace basiccoro
|
109
basic-coro/SingleEvent.hpp
Normal file
109
basic-coro/SingleEvent.hpp
Normal file
|
@ -0,0 +1,109 @@
|
|||
#pragma once
|
||||
|
||||
#include <coroutine>
|
||||
#include <optional>
|
||||
#include <stdexcept>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
namespace basiccoro
|
||||
{
|
||||
namespace detail
|
||||
{
|
||||
|
||||
template<class Event>
|
||||
class AwaiterBase
|
||||
{
|
||||
public:
|
||||
AwaiterBase(Event& event)
|
||||
: event_(event)
|
||||
{}
|
||||
|
||||
bool await_ready()
|
||||
{
|
||||
if (event_.isSet())
|
||||
{
|
||||
// unset already set event, then continue coroutine
|
||||
event_.isSet_ = false;
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void await_suspend(std::coroutine_handle<> handle)
|
||||
{
|
||||
event_.waiting_.push_back(handle);
|
||||
}
|
||||
|
||||
typename Event::value_type await_resume()
|
||||
{
|
||||
if constexpr (!std::is_same_v<typename Event::value_type, void>)
|
||||
{
|
||||
if (!event_.result)
|
||||
{
|
||||
throw std::runtime_error("AwaiterBase: no value in event_.result");
|
||||
}
|
||||
return *event_.result;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Event& event_;
|
||||
};
|
||||
|
||||
class SingleEventBase
|
||||
{
|
||||
public:
|
||||
SingleEventBase() = default;
|
||||
SingleEventBase(const SingleEventBase&) = delete;
|
||||
SingleEventBase(SingleEventBase&&);
|
||||
SingleEventBase& operator=(const SingleEventBase&) = delete;
|
||||
SingleEventBase& operator=(SingleEventBase&&);
|
||||
~SingleEventBase();
|
||||
|
||||
bool isSet() const {
|
||||
return isSet_;
|
||||
}
|
||||
|
||||
protected:
|
||||
void set_common();
|
||||
|
||||
private:
|
||||
template<class T>
|
||||
friend class AwaiterBase;
|
||||
std::vector<std::coroutine_handle<>> waiting_;
|
||||
bool isSet_ = false;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template<class T>
|
||||
class SingleEvent : public detail::SingleEventBase
|
||||
{
|
||||
public:
|
||||
using value_type = T;
|
||||
using awaiter = detail::AwaiterBase<SingleEvent<T>>;
|
||||
|
||||
void set(T t) { result = std::move(t); set_common(); }
|
||||
awaiter operator co_await() { return awaiter{*this}; }
|
||||
|
||||
private:
|
||||
friend awaiter;
|
||||
std::optional<T> result;
|
||||
};
|
||||
|
||||
template<>
|
||||
class SingleEvent<void> : public detail::SingleEventBase
|
||||
{
|
||||
public:
|
||||
using value_type = void;
|
||||
using awaiter = detail::AwaiterBase<SingleEvent<void>>;
|
||||
|
||||
void set() {
|
||||
set_common();
|
||||
}
|
||||
awaiter operator co_await();
|
||||
};
|
||||
|
||||
} // namespace basiccoro
|
37
main.cpp
37
main.cpp
|
@ -1,5 +1,7 @@
|
|||
#include "Runtime.hpp"
|
||||
#include "AsyncManager.hpp"
|
||||
#include "Client.hpp"
|
||||
#include "basic-coro/AwaitableTask.hpp"
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
@ -11,19 +13,12 @@ void on_progress(float progress) {
|
|||
std::cout << unsigned(progress) << '\r' << std::flush;
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
Runtime rt;
|
||||
|
||||
// Print header
|
||||
std::cout << "llama.any running on " PLATFORM ".\n"
|
||||
"\n";
|
||||
|
||||
basiccoro::AwaitableTask<void> async_main(Runtime& rt, AsyncManager &aMan) {
|
||||
// Ask for server address
|
||||
const std::string addr = rt.readInput("Server address");
|
||||
|
||||
// Create client
|
||||
Client client(addr, 99181);
|
||||
Client client(addr, 99181, aMan);
|
||||
|
||||
// Connection loop
|
||||
for (;; rt.cooperate()) {
|
||||
|
@ -37,15 +32,31 @@ int main()
|
|||
std::cout << "Prompt: " << prompt << std::endl;
|
||||
|
||||
// Run inference
|
||||
client.ask(prompt, [&rt] (float progress) {
|
||||
co_await client.ask(prompt, [&rt] (float progress) -> basiccoro::AwaitableTask<void> {
|
||||
std::cout << unsigned(progress) << "%\r" << std::flush;
|
||||
rt.cooperate();
|
||||
}, [&rt] (std::string_view token) {
|
||||
co_return;
|
||||
}, [&rt] (std::string_view token) -> basiccoro::AwaitableTask<void> {
|
||||
std::cout << token << std::flush;
|
||||
rt.cooperate();
|
||||
co_return;
|
||||
});
|
||||
std::cout << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
Runtime rt;
|
||||
AsyncManager aMan(rt);
|
||||
|
||||
// Start async main()
|
||||
async_main(rt, aMan);
|
||||
|
||||
// Print header
|
||||
std::cout << "llama.any running on " PLATFORM ".\n"
|
||||
"\n";
|
||||
|
||||
// Start async manager
|
||||
aMan.run();
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue