diff --git a/main.cpp b/main.cpp index 64fe3d8..923b23c 100644 --- a/main.cpp +++ b/main.cpp @@ -70,24 +70,54 @@ class Server { // Run inference std::cout << "Running interference...\n" << std::endl; - auto result = inference.run("RUN UNTIL EOS...", [&socket] (const char *token) { + bool aborted = false; + auto result = inference.run("RUN UNTIL EOS...", [&socket, &aborted, stop_soon = false] (const char *token) mutable { uint8_t len; const auto token_len = strlen(token); std::cout << token << std::flush; + // Skip empty tokens + if (token_len == 0) return true; + // Send result length len = token_len; socket.send(const_buffer(&len, sizeof(len))); // Send result socket.send(const_buffer(token, token_len)); + + // Check for cancellation request + socket.receive(mutable_buffer(&len, sizeof(len))); + std::cout << unsigned(len) << std::flush; + switch (len) { + case 0xAB: { // Abort + // Stop immediately + std::cout << "... Aborted." << std::endl; + aborted = true; + return false; + } break; + case 0xCA: { // Cancel + // Stop at end of sentence + stop_soon = true; + } break; + } + + // Check for end of sentence + if (stop_soon && (token[0] == '.' || token[token_len-1] == '.')) { + std::cout << "... Cancelled." << std::endl; + return false; + } + + // Continue return true; }); std::cout << std::endl; - // Send zero-length token - len = 0xFF; - socket.send(const_buffer(&len, sizeof(len))); + // Send zero-length token if not aborted + if (!aborted) { + len = 0xFF; + socket.send(const_buffer(&len, sizeof(len))); + } } }