1
0
Fork 0
mirror of https://gitlab.com/niansa/llama_any_server.git synced 2025-03-06 20:53:35 +01:00

Implemented cancellation

This commit is contained in:
niansa 2023-04-08 14:21:54 +02:00
parent a583fadefb
commit af81033f90

View file

@ -70,24 +70,54 @@ class Server {
// Run inference // Run inference
std::cout << "Running interference...\n" << std::endl; std::cout << "Running interference...\n" << std::endl;
auto result = inference.run("RUN UNTIL EOS...", [&socket] (const char *token) { bool aborted = false;
auto result = inference.run("RUN UNTIL EOS...", [&socket, &aborted, stop_soon = false] (const char *token) mutable {
uint8_t len; uint8_t len;
const auto token_len = strlen(token); const auto token_len = strlen(token);
std::cout << token << std::flush; std::cout << token << std::flush;
// Skip empty tokens
if (token_len == 0) return true;
// Send result length // Send result length
len = token_len; len = token_len;
socket.send(const_buffer(&len, sizeof(len))); socket.send(const_buffer(&len, sizeof(len)));
// Send result // Send result
socket.send(const_buffer(token, token_len)); socket.send(const_buffer(token, token_len));
// Check for cancellation request
socket.receive(mutable_buffer(&len, sizeof(len)));
std::cout << unsigned(len) << std::flush;
switch (len) {
case 0xAB: { // Abort
// Stop immediately
std::cout << "... Aborted." << std::endl;
aborted = true;
return false;
} break;
case 0xCA: { // Cancel
// Stop at end of sentence
stop_soon = true;
} break;
}
// Check for end of sentence
if (stop_soon && (token[0] == '.' || token[token_len-1] == '.')) {
std::cout << "... Cancelled." << std::endl;
return false;
}
// Continue
return true; return true;
}); });
std::cout << std::endl; std::cout << std::endl;
// Send zero-length token // Send zero-length token if not aborted
len = 0xFF; if (!aborted) {
socket.send(const_buffer(&len, sizeof(len))); len = 0xFF;
socket.send(const_buffer(&len, sizeof(len)));
}
} }
} }