// llama_nds/Client.cpp (C++, 105 lines, 3 KiB), dated 2023-04-07 18:43:33 +02:00
// Mirror of https://gitlab.com/niansa/llama_nds.git, synced 2025-03-06 20:53:28 +01:00
// (repository page counters at time of capture: 1 / 0 / Fork 0)

#include "Client.hpp"
#include "Socket.hpp"
#include "Sender.hpp"
#include "Receiver.hpp"
#include <cerrno>
#include <cstring>
#ifndef PLATFORM_WINDOWS
# include <sys/types.h>
# include <sys/socket.h>
# include <netinet/in.h>
# include <arpa/inet.h>
# include <netdb.h>
#else
# include <ws2tcpip.h>
#endif
// Resolves the hostname `addr` with gethostbyname() and stores the result in
// the `addrInfo` member.
// `port` is accepted but not used by this function.
// Throws Exception when the lookup returns no entry, or an entry whose
// address list is empty.
void Client::fetchAddr(const std::string& addr, [[maybe_unused]] unsigned port) {
    addrInfo = gethostbyname(addr.c_str());
    // Snapshot errno immediately, before anything else can overwrite it.
    // NOTE(review): gethostbyname() reports failures through h_errno (or
    // WSAGetLastError() on Windows) rather than errno, so the text produced
    // below may not describe the real cause. Confirm what the NDS (dswifi)
    // resolver sets before switching to h_errno.
    const int savedErrno = errno;
    const bool resolved = addrInfo != nullptr && addrInfo->h_addr_list[0] != nullptr;
    if (!resolved) {
        throw Exception("DNS failed to look up hostname: " + std::string(strerror(savedErrno)) + " (" + addr + ')');
    }
}
// Creates an IPv4 TCP socket, resolves `addr` via fetchAddr() and connects to
// `addr`:`port`.
// Throws Exception if:
//   - the socket cannot be created,
//   - the hostname cannot be resolved (thrown by fetchAddr()), or
//   - connect() fails.
Client::Client(const std::string& addr, unsigned port, AsyncManager& asyncManager) : aMan(asyncManager) {
    // Create socket
    connection = std::make_unique<SocketConnection<Sender::Simple, Receiver::Simple>>(aMan, Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); //TODO: Care about IPv6
    if (*connection < 0) [[unlikely]] {
        throw Exception("Failed to create TCP socket");
    }
    // Fetch address (fills in addrInfo)
    fetchAddr(addr, port);
    // Build the destination address. Value-initialise it so that sin_zero (and
    // any platform-specific extra fields) are zero rather than stack garbage.
    sockaddr_in sain{};
    sain.sin_family = AF_INET;
    sain.sin_port = htons(port);
    // Copy exactly sizeof(in_addr) bytes of the first resolved address.
    // Dereferencing it as `unsigned long` read 8 bytes on LP64 targets: an
    // over-read and a strict-aliasing violation.
    std::memcpy(&sain.sin_addr, addrInfo->h_addr_list[0], sizeof(sain.sin_addr));
    // Connect to server
    if (connect(*connection, reinterpret_cast<sockaddr *>(&sain), sizeof(sain)) != 0) [[unlikely]] {
        throw Exception("Connection has been declined");
    }
}
// Sends `prompt` to the server and streams the reply back through the callbacks.
//
// Wire protocol, as implemented below:
//   -> 1-byte prompt length, followed by the prompt bytes
//   <- 1-byte progress values, each passed to on_progress, until one equals 100
//   <- repeated [1-byte length][length bytes] tokens, each passed to on_token;
//      a length of 0 is skipped and a length of 0xFF ends the response
//
// Returns:
//   - AsyncResult::Error if the prompt is longer than 255 bytes (it cannot be
//     described by the 1-byte length prefix),
//   - AsyncResult::Error if any read or write reports an error,
//   - AsyncResult::Success once the 0xFF terminator has been received.
// Callbacks that are empty std::function objects are skipped.
basiccoro::AwaitableTask<AsyncResult> Client::ask(std::string_view prompt, const std::function<basiccoro::AwaitableTask<void> (unsigned progress)>& on_progress, const std::function<basiccoro::AwaitableTask<void> (std::string_view token)>& on_token) {
    // The length prefix is a single byte. An oversized prompt used to be
    // announced with length % 256 but then sent in full, which desynchronised
    // the stream. Reject it before anything goes on the wire.
    if (prompt.length() > 0xFF) {
        co_return AsyncResult::Error;
    }
    // Send prompt length
    uint8_t len = static_cast<uint8_t>(prompt.length());
    if (co_await connection->writeObject(len, true) == AsyncResult::Error) {
        co_return AsyncResult::Error;
    }
    // Send prompt
    if (co_await connection->write(prompt) == AsyncResult::Error) {
        co_return AsyncResult::Error;
    }
    // Receive progress updates until the server reports 100%
    for (;;) {
        uint8_t progress{};
        if (co_await connection->readObject(progress) == AsyncResult::Error) {
            co_return AsyncResult::Error;
        }
        if (on_progress) {
            co_await on_progress(progress);
        }
        if (progress == 100) break;
    }
    // Receive response tokens
    for (;;) {
        // Receive token length (reuses `len`; the prompt length is no longer needed)
        if (co_await connection->readObject(len) == AsyncResult::Error) {
            co_return AsyncResult::Error;
        }
        // 0xFF marks the end of the response
        if (len == 0xFF) break;
        // Skip empty token
        if (len == 0) continue;
        // Receive token bytes; an empty result is treated as a read failure
        const auto token = co_await connection->read(len);
        if (token.empty()) {
            co_return AsyncResult::Error;
        }
        if (on_token) {
            co_await on_token(token);
        }
    }
    // No error
    co_return AsyncResult::Success;
}