mirror of
https://gitlab.com/niansa/llama_any_server.git
synced 2025-03-06 20:53:35 +01:00
Implemented cancellation
This commit is contained in:
parent
a583fadefb
commit
af81033f90
1 changed file with 34 additions and 4 deletions
34
main.cpp
34
main.cpp
|
@ -70,26 +70,56 @@ class Server {
|
|||
|
||||
// Run inference
|
||||
std::cout << "Running interference...\n" << std::endl;
|
||||
auto result = inference.run("RUN UNTIL EOS...", [&socket] (const char *token) {
|
||||
bool aborted = false;
|
||||
auto result = inference.run("RUN UNTIL EOS...", [&socket, &aborted, stop_soon = false] (const char *token) mutable {
|
||||
uint8_t len;
|
||||
const auto token_len = strlen(token);
|
||||
std::cout << token << std::flush;
|
||||
|
||||
// Skip empty tokens
|
||||
if (token_len == 0) return true;
|
||||
|
||||
// Send result length
|
||||
len = token_len;
|
||||
socket.send(const_buffer(&len, sizeof(len)));
|
||||
|
||||
// Send result
|
||||
socket.send(const_buffer(token, token_len));
|
||||
|
||||
// Check for cancellation request
|
||||
socket.receive(mutable_buffer(&len, sizeof(len)));
|
||||
std::cout << unsigned(len) << std::flush;
|
||||
switch (len) {
|
||||
case 0xAB: { // Abort
|
||||
// Stop immediately
|
||||
std::cout << "... Aborted." << std::endl;
|
||||
aborted = true;
|
||||
return false;
|
||||
} break;
|
||||
case 0xCA: { // Cancel
|
||||
// Stop at end of sentence
|
||||
stop_soon = true;
|
||||
} break;
|
||||
}
|
||||
|
||||
// Check for end of sentence
|
||||
if (stop_soon && (token[0] == '.' || token[token_len-1] == '.')) {
|
||||
std::cout << "... Cancelled." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Continue
|
||||
return true;
|
||||
});
|
||||
std::cout << std::endl;
|
||||
|
||||
// Send zero-length token
|
||||
// Send zero-length token if not aborted
|
||||
if (!aborted) {
|
||||
len = 0xFF;
|
||||
socket.send(const_buffer(&len, sizeof(len)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
// Construct the server: build an IPv4 TCP endpoint on the hardcoded port and
// create the acceptor on it.
//
// NOTE: the port used to be written as `(unsigned short)99181`. 99181 does not
// fit in a 16-bit port number, so with a 16-bit unsigned short the cast wrapped
// it to 33645 (99181 - 65536). 33645 is therefore the port that was really
// being bound; it is spelled out here so the effective value is visible and
// clients connect to the same port as before.
Server() : endpoint(ip::tcp::v4(), static_cast<unsigned short>(33645)/*TODO: do not hardcode port*/), acceptor(service, endpoint) {}
|
||||
|
|
Loading…
Add table
Reference in a new issue