#include #include #include namespace Application { using namespace boost::asio; class Server { io_context service; ip::tcp::endpoint endpoint; ip::tcp::acceptor acceptor; static inline const LM::Inference::Params& get_params() noexcept { static auto params = [] () { LM::Inference::Params params; params.n_batch = 8; params.n_ctx = 2048; params.n_repeat_last = 64; params.repeat_penalty = 1.3f; params.temp = 0.1f; params.top_k = 40; params.top_p = 0.95f; params.use_mlock = false; return params; }(); return params; } void client_run(ip::tcp::socket& socket) { uint8_t len; // Create inference instance LM::Inference inference("gpt4all-lora-unfiltered-quantized.bin"/*TODO: do not hardcode path*/, get_params()); for (bool first_run = true; ; first_run = false) { // Receive prompt length std::cout << "Receiving prompt length..." << std::endl; socket.receive(mutable_buffer(&len, sizeof(len))); // Receive prompt std::string prompt; prompt.resize(len); std::cout << "Receiving prompt of length " << unsigned(len) << "..." << std::endl; socket.receive(mutable_buffer(prompt.data(), len)); // Stop on zero length if (len == 0) break; // Append prompt std::cout << "Evaluating prompt..." << std::endl; uint8_t old_progress = 0; inference.append(std::string(first_run?"Below is an instruction that describes a task. Write a response that appropriately completes the request.":"")+"\n\n### Instruction:\n\n"+prompt+"\n\n### Response:\n\n", [&old_progress, &socket] (float progress) { uint8_t progress_i = progress; // Report new progress if (old_progress != progress_i) { socket.send(const_buffer(&progress_i, sizeof(progress_i))); // Set as old progress old_progress = progress_i; } return true; }); // Report completion if needed if (old_progress != 100) { old_progress = 100; socket.send(const_buffer(&old_progress, sizeof(old_progress))); } // Run inference std::cout << "Running interference...\n" << std::endl; bool aborted = false; auto result = inference.run("RUN UNTIL EOS...", [&socket, &aborted, stop_soon = false] (const char *token) mutable { uint8_t len; const auto token_len = strlen(token); std::cout << token << std::flush; // Skip empty tokens if (token_len == 0) return true; // Send result length len = token_len; socket.send(const_buffer(&len, sizeof(len))); // Send result socket.send(const_buffer(token, token_len)); // Check for cancellation request socket.receive(mutable_buffer(&len, sizeof(len))); std::cout << unsigned(len) << std::flush; switch (len) { case 0xAB: { // Abort // Stop immediately std::cout << "... Aborted." << std::endl; aborted = true; return false; } break; case 0xCA: { // Cancel // Stop at end of sentence stop_soon = true; } break; } // Check for end of sentence if (stop_soon && (token[0] == '.' || token[token_len-1] == '.')) { std::cout << "... Cancelled." << std::endl; return false; } // Continue return true; }); std::cout << std::endl; // Send zero-length token if not aborted if (!aborted) { len = 0xFF; socket.send(const_buffer(&len, sizeof(len))); } } } public: Server() : endpoint(ip::tcp::v4(), (unsigned short)99181/*TODO: do not hardcode port*/), acceptor(service, endpoint) {} void run() { std::cout << "Waiting for connection..." << std::endl; // Wait for connections infinitely for (;;) { // Accept connection immediately ip::tcp::socket socket(service); acceptor.accept(socket); // Start thread for new connection std::cout << "Accepted connection, starting connection thread..." << std::endl; std::thread([this, socket = std::move(socket)] () mutable { client_run(socket); }).detach(); } } }; } int main() { Application::Server server; server.run(); }