From ea861bcac9d233446e9a19308c026a88b1e264c0 Mon Sep 17 00:00:00 2001
From: niansa
Date: Mon, 3 Apr 2023 22:03:25 +0200
Subject: [PATCH] Initial commit

---
 .gitmodules    |   3 ++
 CMakeLists.txt |   4 ++
 libjustlm      |   1 +
 main.cpp       | 121 +++++++++++++++++++++++++++++++++++++++++++++++--
 4 files changed, 124 insertions(+), 5 deletions(-)
 create mode 100644 .gitmodules
 create mode 160000 libjustlm

diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..ac21a73
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "libjustlm"]
+	path = libjustlm
+	url = https://gitlab.com/niansa/libjustlm.git
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 280b5a1..d0fc85f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -5,7 +5,11 @@ project(llama_any_server LANGUAGES CXX)
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 
+add_subdirectory(libjustlm)
+
 add_executable(llama_any_server main.cpp)
 
+target_link_libraries(llama_any_server PUBLIC libjustlm)
+
 install(TARGETS llama_any_server
     LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})
diff --git a/libjustlm b/libjustlm
new file mode 160000
index 0000000..8b5a375
--- /dev/null
+++ b/libjustlm
@@ -0,0 +1 @@
+Subproject commit 8b5a375f592099ef12d3988087fe699b3ecba3bb
diff --git a/main.cpp b/main.cpp
index 3129fb9..f047c09 100644
--- a/main.cpp
+++ b/main.cpp
@@ -1,9 +1,120 @@
 #include <iostream>
+#include <boost/asio.hpp>
+#include <justlm.hpp>
 
-using namespace std;
 
-int main()
-{
-    cout << "Hello World!" << endl;
-    return 0;
+
+
+namespace Application {
+using namespace boost::asio;
+
+class Server {
+    io_context service;
+    ip::tcp::endpoint endpoint;
+    ip::tcp::acceptor acceptor;
+
+    static inline
+    const LM::Inference::Params& get_params() noexcept {
+        static auto params = [] () {
+            LM::Inference::Params params;
+            params.n_batch = 8;
+            params.n_ctx = 2048;
+            params.n_repeat_last = 64;
+            params.repeat_penalty = 1.3f;
+            params.temp = 0.1f;
+            params.top_k = 40;
+            params.top_p = 0.95f;
+            params.use_mlock = false;
+            return params;
+        }();
+        return params;
+    }
+
+    void client_run(ip::tcp::socket& socket) {
+        uint8_t len;
+
+        // Create inference instance
+        LM::Inference inference("gpt4all-lora-unfiltered-quantized.bin"/*TODO: do not hardcode path*/, get_params());
+
+        for (bool first_run = true; ; first_run = false) {
+            // Receive prompt length
+            std::cout << "Receiving prompt length..." << std::endl;
+            socket.receive(mutable_buffer(&len, sizeof(len)));
+
+            // Receive prompt
+            std::string prompt;
+            prompt.resize(len);
+            std::cout << "Receiving prompt of length " << unsigned(len) << "..." << std::endl;
+            socket.receive(mutable_buffer(prompt.data(), len));
+
+            // Stop on zero length
+            if (len == 0) break;
+
+            // Append prompt
+            std::cout << "Evaluating prompt..." << std::endl;
+            uint8_t old_progress = 0;
+            inference.append(std::string(first_run?"Below is an instruction that describes a task. Write a response that appropriately completes the request.":"")+"\n\n### Instruction:\n\n"+prompt+"\n\n### Response:\n\n", [&old_progress, &socket] (float progress) {
+                uint8_t progress_i = progress;
+                // Report new progress
+                if (old_progress != progress_i) {
+                    socket.send(const_buffer(&progress_i, sizeof(progress_i)));
+                    // Set as old progress
+                    old_progress = progress_i;
+                }
+                return true;
+            });
+            // Report completion if needed
+            if (old_progress != 100) {
+                old_progress = 100;
+                socket.send(const_buffer(&old_progress, sizeof(old_progress)));
+            }
+
+            // Run inference
+            std::cout << "Running inference...\n" << std::endl;
+            auto result = inference.run("\n", [&socket] (const char *token) {
+                uint8_t len;
+                const auto token_len = strlen(token);
+                std::cout << token << std::flush;
+
+                // Send result length
+                len = token_len;
+                socket.send(const_buffer(&len, sizeof(len)));
+
+                // Send result
+                socket.send(const_buffer(token, token_len));
+                return true;
+            });
+            std::cout << std::endl;
+
+            // Send end-of-generation marker
+            len = 0xFF;
+            socket.send(const_buffer(&len, sizeof(len)));
+        }
+    }
+
+public:
+    Server() : endpoint(ip::tcp::v4(), (unsigned short)99181/*TODO: do not hardcode port*/), acceptor(service, endpoint) {}
+
+    void run() {
+        std::cout << "Waiting for connection..." << std::endl;
+
+        // Wait for connections infinitely
+        for (;;) {
+            // Accept connection immediately
+            ip::tcp::socket socket(service);
+            acceptor.accept(socket);
+
+            // Start thread for new connection
+            std::cout << "Accepted connection, starting connection thread..." << std::endl;
+            std::thread([this, socket = std::move(socket)] () mutable {
+                client_run(socket);
+            }).detach();
+        }
+    }
+};
+}
+
+
+int main() {
+    Application::Server server;
+    server.run();
 }
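
For reference, client_run() above implements a small length-prefixed protocol: the client sends one length byte followed by that many bytes of prompt text (a zero length ends the session); while the prompt is evaluated the server streams single progress bytes (0-100, always finishing with 100), then streams length-prefixed response tokens and terminates the response with a 0xFF length byte. The sketch below is a minimal test client for that protocol and is not part of the patch: the host/port arguments and the example prompt are placeholders, Boost.Asio is assumed to be available, and note that the server casts 99181 to unsigned short, so the port it actually binds is 99181 % 65536 = 33645.

// Minimal test client for the server's wire protocol (sketch, not part of the patch).
// Hypothetical usage: ./client <host> <port>
#include <boost/asio.hpp>
#include <cstdint>
#include <iostream>
#include <string>

int main(int argc, char **argv) {
    using namespace boost::asio;
    if (argc != 3) {
        std::cerr << "Usage: " << argv[0] << " <host> <port>" << std::endl;
        return 1;
    }

    io_context ctx;
    ip::tcp::socket socket(ctx);
    ip::tcp::resolver resolver(ctx);
    connect(socket, resolver.resolve(argv[1], argv[2]));

    // Send one prompt: a single length byte, then that many bytes of text
    std::string prompt = "Say hello."; // placeholder prompt
    uint8_t len = prompt.size();
    write(socket, const_buffer(&len, sizeof(len)));
    write(socket, buffer(prompt));

    // Phase 1: progress bytes until the server reports 100%
    uint8_t progress = 0;
    while (progress != 100) {
        read(socket, mutable_buffer(&progress, sizeof(progress)));
        std::cout << "\rEvaluating prompt: " << unsigned(progress) << "%" << std::flush;
    }
    std::cout << std::endl;

    // Phase 2: length-prefixed tokens until the 0xFF end marker
    for (;;) {
        uint8_t token_len;
        read(socket, mutable_buffer(&token_len, sizeof(token_len)));
        if (token_len == 0xFF) break;
        std::string token(token_len, '\0');
        read(socket, buffer(token));
        std::cout << token << std::flush;
    }
    std::cout << std::endl;

    // A zero-length prompt tells the server to end the session
    len = 0;
    write(socket, const_buffer(&len, sizeof(len)));
}

Because boost::asio::read() and write() are composed operations that transfer the full requested byte count, the client does not have to deal with short reads the way the server's single receive() calls might.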