// Mirror of https://gitlab.com/niansa/llama_nds.git
// (synced 2025-03-06 20:53:28 +01:00; 111 lines, 3.3 KiB, C++)
#include "Client.hpp"
#include "Socket.hpp"
#include "Sender.hpp"
#include "Receiver.hpp"

#include <cerrno>
#include <cstring>
#include <limits>
#ifndef PLATFORM_WINDOWS
# include <sys/types.h>
# include <sys/socket.h>
# include <netinet/in.h>
# include <arpa/inet.h>
# include <netdb.h>
#else
# include <ws2tcpip.h>
#endif
void Client::fetchAddr(const std::string& addr, unsigned port) {
|
|
addrInfo = gethostbyname(addr.c_str());
|
|
auto error = errno;
|
|
auto bad = addrInfo == nullptr || addrInfo->h_addr_list[0] == nullptr;
|
|
|
|
// Check for error
|
|
if (bad) {
|
|
throw Exception("DNS failed to look up hostname: "+std::string(strerror(error))+" ("+addr+')');
|
|
}
|
|
}
// Creates a TCP socket over IPv4, resolves `addr` through fetchAddr(), and
// connects to it on `port`. The connection is bound to `asyncManager`.
//
// Throws Exception if the socket cannot be created, the hostname cannot be
// resolved, or connect() fails.
Client::Client(const std::string& addr, unsigned port, AsyncManager& asyncManager) : aMan(asyncManager) {
    // Create socket
    connection = std::make_unique<SocketConnection<Sender::Simple, Receiver::Simple>>(aMan, Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); //TODO: Care about IPv6
    if (*connection < 0) [[unlikely]] {
        throw Exception("Failed to create TCP socket");
    }

    // Fetch address
    fetchAddr(addr, port);

    // Connect to server.
    // Value-initialize so padding such as sin_zero (and sin_len on platforms
    // that have it) is zeroed instead of indeterminate.
    sockaddr_in sain{};
    sain.sin_family = AF_INET;
    sain.sin_port = htons(port);
    // h_addr_list[0] points at raw address bytes already in network byte
    // order. Copy exactly sizeof(in_addr) bytes: dereferencing the buffer as
    // `unsigned long` would read past it where long is 64-bit, and it would
    // also violate strict aliasing.
    std::memcpy(&sain.sin_addr, addrInfo->h_addr_list[0], sizeof(sain.sin_addr));
    if (connect(*connection, reinterpret_cast<sockaddr *>(&sain), sizeof(sain)) != 0) [[unlikely]] {
        throw Exception("Connection has been declined");
    }
}
// Sends `prompt` to the server and streams the reply back through callbacks.
//
// Wire exchange as implemented here:
//   -> 1 byte prompt length, then the prompt bytes.
//   <- 1 byte progress values; on_progress is awaited for each one, and the
//      loop stops once a value equals 100.
//   <- repeated: 1 byte token length (0xFF ends the response), then that many
//      bytes. on_token is awaited for each token. After each token one
//      control byte is sent back, chosen from the held keys:
//        0xAB  abort    (KEY_START)
//        0xCA  cancel   (KEY_SELECT)
//        0xC0  continue (otherwise)
//
// Returns AsyncResult::Error on any failed read/write, on an empty token
// read, or when the prompt is too long for the 1-byte length prefix.
// Otherwise returns AsyncResult::Success, including when the user aborts
// with START.
basiccoro::AwaitableTask<AsyncResult> Client::ask(std::string_view prompt, const std::function<basiccoro::AwaitableTask<void> (unsigned progress)>& on_progress, const std::function<basiccoro::AwaitableTask<void> (std::string_view token)>& on_token) {
    // The length prefix is one byte. A longer prompt would be announced with a
    // truncated length but sent in full, desynchronizing the stream, so reject
    // it before anything is written.
    constexpr auto maxPromptLength = std::numeric_limits<uint8_t>::max();
    if (prompt.length() > maxPromptLength) {
        co_return AsyncResult::Error;
    }

    // Send prompt length
    uint8_t len = static_cast<uint8_t>(prompt.length());
    if (co_await connection->writeObject(len, true) == AsyncResult::Error) {
        co_return AsyncResult::Error;
    }

    // Send prompt
    if (co_await connection->write(prompt) == AsyncResult::Error) {
        co_return AsyncResult::Error;
    }

    // Receive progress
    for (;;) {
        uint8_t progress{};

        // Receive percentage
        if (co_await connection->readObject(progress) == AsyncResult::Error) {
            co_return AsyncResult::Error;
        }

        // Run on_progress callback
        co_await on_progress(progress);

        // Stop at 100%
        if (progress == 100) break;
    }

    // Receive response
    for (;;) {
        // Receive token length
        uint8_t tokenLen{};
        if (co_await connection->readObject(tokenLen) == AsyncResult::Error) {
            co_return AsyncResult::Error;
        }

        // A length byte of 0xFF marks the end of the response
        if (tokenLen == 0xFF) break;

        // Receive token
        const auto token = co_await connection->read(tokenLen);
        if (token.empty()) {
            co_return AsyncResult::Error;
        }

        // Run on_token callback
        co_await on_token(token);

        // Pass cancellation request; a failed control write is an I/O error
        const int keys = keysCurrent();
        if (keys & KEY_START) {
            // Abort
            if (co_await connection->writeObject<uint8_t>(0xAB) == AsyncResult::Error) {
                co_return AsyncResult::Error;
            }
            co_return AsyncResult::Success;
        }
        const uint8_t control = (keys & KEY_SELECT) ? 0xCA  // Cancel
                                                    : 0xC0; // Continue
        if (co_await connection->writeObject<uint8_t>(control) == AsyncResult::Error) {
            co_return AsyncResult::Error;
        }
    }

    // No error
    co_return AsyncResult::Success;
}