1
0
Fork 0
mirror of https://gitlab.com/niansa/llama_any_server.git synced 2025-03-06 20:53:35 +01:00

Implemented cancellation

This commit is contained in:
niansa 2023-04-08 14:21:54 +02:00
parent a583fadefb
commit af81033f90

View file

@ -70,26 +70,56 @@ class Server {
// Run inference
std::cout << "Running interference...\n" << std::endl;
auto result = inference.run("RUN UNTIL EOS...", [&socket] (const char *token) {
bool aborted = false;
auto result = inference.run("RUN UNTIL EOS...", [&socket, &aborted, stop_soon = false] (const char *token) mutable {
uint8_t len;
const auto token_len = strlen(token);
std::cout << token << std::flush;
// Skip empty tokens
if (token_len == 0) return true;
// Send result length
len = token_len;
socket.send(const_buffer(&len, sizeof(len)));
// Send result
socket.send(const_buffer(token, token_len));
// Check for cancellation request
socket.receive(mutable_buffer(&len, sizeof(len)));
std::cout << unsigned(len) << std::flush;
switch (len) {
case 0xAB: { // Abort
// Stop immediately
std::cout << "... Aborted." << std::endl;
aborted = true;
return false;
} break;
case 0xCA: { // Cancel
// Stop at end of sentence
stop_soon = true;
} break;
}
// Check for end of sentence
if (stop_soon && (token[0] == '.' || token[token_len-1] == '.')) {
std::cout << "... Cancelled." << std::endl;
return false;
}
// Continue
return true;
});
std::cout << std::endl;
// Send zero-length token
// Send zero-length token if not aborted
if (!aborted) {
len = 0xFF;
socket.send(const_buffer(&len, sizeof(len)));
}
}
}
public:
Server() : endpoint(ip::tcp::v4(), (unsigned short)99181/*TODO: do not hardcode port*/), acceptor(service, endpoint) {}